First, we load the dataset and configure pandas display options.
import numpy as np
from IPython.display import display, HTML
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import pandas as pd

# Show every column when displaying wide DataFrames.
pd.set_option('display.max_columns', None)

# Original raw extract (kept for provenance):
#df = pd.read_csv('S:/MHS_Pathfinder/bipolar_prediction/GOLD_&_AURUM_15_04_20.csv', low_memory=False)
df = pd.read_csv('clean_dataset.csv', low_memory=False)
NB: Because lo_HDL has missing values, pandas interprets the whole column as float by default.
# List every available column, alphabetically.
for column_name in sorted(df.columns):
    print(column_name)
BMI BMI_date BP_date CHD CHD_date CKD3 Chronic_pulmonary_disease Chronic_pulmonary_disease_date Coagulopathy Coagulopathy_date Congestive_heart_failure Congestive_heart_failure_date Deficiency_anaemia Deficiency_anaemia_date Diabetes_organ_damage Diabetes_organ_damage_date Diabetes_uncomplicated Diabetes_uncomplicated_date FH_BPD FH_BPD_date FH_LD FH_LD_date FH_NOS FH_NOS_date FH_anxiety FH_anxiety_date FH_any FH_depression FH_depression_date FH_psychosis FH_psychosis_date FH_substance FH_substance_date FH_suicide FH_suicide_date Fluid_electrolyte_disorder_date Fluid_electrolyte_disorders HDL HDL_date HIV_AIDS HIV_AIDS_date LDL LDL_date Liver_disease Liver_disease_date N_dep_b4 N_man_b4 Neurological_disorders Neurological_disorders_date OCD OCD_date PD PD_date Peptic_ulcer Peptic_ulcer_date Peripheral_vascular Peripheral_vascular_date Pulmonary_circulation Pulmonary_circulation_date RA RA_date SSRI SSRI_b4 SSRI_during T2DM T2DM_date TCA TCA_b4 TCA_during TSH TSH_date Valvular_disease Valvular_disease_date Weight_loss Weight_loss_date adhd adhd_date age_BMI age_BP age_CHD age_Chronic_pulmonary_disease age_Coagulopathy age_Congestive_heart_failure age_Deficiency_anaemia age_Diabetes_organ_damage age_Diabetes_uncomplicated age_FH_BPD age_FH_LD age_FH_NOS age_FH_anxiety age_FH_depression age_FH_psychosis age_FH_substance age_FH_suicide age_Fluid_electrolyte_disorder age_HDL age_HIV_AIDS age_LDL age_Liver_disease age_Neurological_disorders age_OCD age_PD age_Peptic_ulcer age_Peripheral_vascular age_Pulmonary_circulation age_RA age_T2DM age_TSH age_Valvular_disease age_Weight_loss age_adhd age_alcohol age_anxiety age_asthma age_ca age_cannabis age_cardiac_arrythmia age_conduct age_death age_depression age_dermatitis age_diagnosis age_eGFR age_ethnicity age_first age_first_AP age_first_MS age_first_SSRI age_first_TCAs age_first_diagnosis age_first_exposure age_first_li age_first_olan age_first_other_ADs age_first_reg age_hyperthyroid age_hypothyroid age_last_SSRI 
age_last_TCAs age_last_other_ADs age_mania age_migraine age_other_substance_misuse age_psych_FH age_psychosis age_relationship age_self_harm age_sleep age_smoke age_stress age_transfer_out alcohol alcohol_date anxiety anxiety_date any_AD_b4 any_AD_during ap_b4 ap_duration asthma asthma_date ca ca_date cannabis cannabis_date cardiac_arrythmia cardiac_arrythmia_date cohort_end cohort_start conduct conduct_date death_date depression depression_date dermatitis dermatitis_date diagnosis_date diastolic dob dominant eGFR_date end_reason ethnicity ethnicity_date ex_time exposure exposure_end exposure_start first_AP_date first_MS_date first_SSRI_date first_TCAs_date first_date first_episode first_li_date first_olan_date first_other_ADs_date first_reg_date hi_LDL hi_ca hypertension hyperthyroid hyperthyroid_date hypothyroid hypothyroid_combined hypothyroid_date incident_script last_SSRI_date last_TCAs_date last_other_ADs_date li_b4 lo_HDL lo_ca mania mania_date mania_type migraine migraine_date ms_b4 ms_duration olan_b4 other_AD_b4 other_AD_during other_ADs other_substance_misuse other_substance_misuse_date patid pracid psych_FH_date psychosis psychosis_date relationship relationship_date responder2 response2_1 self_harm self_harm_date sex sleep sleep_date smoke_date smoker source stress stress_date suitable symptom_to_diagnosis symptom_to_exposure systolic thyroid_blood transfer_out_date weight year_exposure yob
# How many unique patients, and how do they split by suitability?
print("Total patients:", df['patid'].nunique())
print(df['suitable'].value_counts(dropna=False))
Total patients: 38957 1 31518 0 7439 Name: suitable, dtype: int64
Now we only keep patients that are suitable for the inclusion analysis (suitable==1). Those are defined as patients with >2 years of follow-up after exposure_start.
# Keep an untouched copy of the full dataset before filtering.
df_old = df.copy()
# Restrict to suitable patients. The explicit .copy() makes df an
# independent DataFrame, so later cells that assign into it with .loc
# (e.g. the illicit-value fixes) do not trigger SettingWithCopyWarning
# or write into a view of the original frame.
df = df.loc[df.suitable == 1].copy()
print("New total patients:", df.patid.nunique())
New total patients: 31518
# Summarise a few key columns: raw counts (NaN included) plus describe().
for column in ('source', 'exposure', 'suitable', 'response2_1',
               'responder2', 'symptom_to_exposure', 'exposure_end'):
    print(column)
    print(df[column].value_counts(dropna=False))
    print(df[column].describe())
    print()
source
AURUM 20910
GOLD 10608
Name: source, dtype: int64
count 31518
unique 2
top AURUM
freq 20910
Name: source, dtype: object
exposure
lithium 19106
olanzapine 12412
Name: exposure, dtype: int64
count 31518
unique 2
top lithium
freq 19106
Name: exposure, dtype: object
suitable
1 31518
Name: suitable, dtype: int64
count 31518.0
mean 1.0
std 0.0
min 1.0
25% 1.0
50% 1.0
75% 1.0
max 1.0
Name: suitable, dtype: float64
response2_1
0.0 14785
1.0 11848
NaN 4885
Name: response2_1, dtype: int64
count 26633.000000
mean 0.444862
std 0.496960
min 0.000000
25% 0.000000
50% 0.000000
75% 1.000000
max 1.000000
Name: response2_1, dtype: float64
responder2
0.0 19670
1.0 11848
Name: responder2, dtype: int64
count 31518.000000
mean 0.375912
std 0.484365
min 0.000000
25% 0.000000
50% 0.000000
75% 1.000000
max 1.000000
Name: responder2, dtype: float64
symptom_to_exposure
0.000000 3251
0.038330 46
0.002738 45
0.019165 42
0.093087 38
...
4.098563 1
23.208761 1
20.591375 1
-15.641341 1
14.631075 1
Name: symptom_to_exposure, Length: 11420, dtype: int64
count 31518.000000
mean 10.186286
std 12.650202
min -30.031485
25% 0.813142
50% 6.557153
75% 15.629706
max 148.963730
Name: symptom_to_exposure, dtype: float64
exposure_end
04mar2019 71
01mar2019 65
26feb2019 64
19feb2019 59
18feb2019 58
..
30mar2013 1
19jun1989 1
24sep2007 1
24jan2014 1
09dec1999 1
Name: exposure_end, Length: 7796, dtype: int64
count 31518
unique 7796
top 04mar2019
freq 71
Name: exposure_end, dtype: object
# Hand-picked shortlist of clinically informed predictors.
features13 = [
    'age_first_exposure', 'age_first_diagnosis', 'symptom_to_exposure',
    'psychosis', 'depression', 'mania', 'dominant', 'sex',
    'FH_BPD', 'FH_depression', 'FH_psychosis', 'weight', 'self_harm',
]
# Row labels for Table 1 (baseline characteristics by exposure group).
table1_rows = """Total, N
Age at diagnosis, median (IQR)
Age at medication initiation, median (IQR)
Years between diagnosis and exposure, median (IQR)
Female, n (%)
First presentation mania, n (%)
First presentation depression, n (%)
Depression dominant, n (%)
Psychotic experiences, n (%)
Self-harm history, n (%)
Smoker, n (%)
Family history for bipolar disorder, n (%)
Family history for depression, n (%)
Family history for psychosis, n (%)
Overweight or obese, n (%)"""
features_list = table1_rows.split('\n')
Table1 = pd.DataFrame(features_list, columns=['Features'])

for exp, exp_df in df.groupby('exposure'):
    data_list = [str(len(exp_df))]
    # Continuous features: median (IQR). Using pandas quantile/median
    # throughout so missing values are consistently skipped (np.quantile
    # would return NaN on columns with any missing data, while .median()
    # ignores them — the original mixed the two).
    for feature in ['age_first_diagnosis', 'age_first_exposure', 'symptom_to_exposure']:
        q1 = exp_df[feature].quantile(0.25)
        median = exp_df[feature].median()
        q3 = exp_df[feature].quantile(0.75)
        data_list.append("{:.2f} ({:.2f})".format(median, q3 - q1))
    # Binary / categorical features: n (%). A few columns need a specific
    # positive category rather than a plain sum.
    for feature in ['sex', 'mania', 'depression', 'dominant', 'psychosis',
                    'self_harm', 'smoker', 'FH_BPD', 'FH_depression',
                    'FH_psychosis', 'weight']:
        if feature == 'sex':
            count = (exp_df[feature] == 'female').sum()
        elif feature == 'dominant':
            count = (exp_df[feature] == 'depression').sum()
        elif feature == 'smoker':
            count = exp_df[feature].isin(['ex smoker', 'current smoker']).sum()
        elif feature == 'weight':
            count = exp_df[feature].isin(['overweight', 'obese']).sum()
        else:
            count = exp_df[feature].sum()
        data_list.append("{} ({:.2f} %)".format(count, count / len(exp_df) * 100))
    Table1[exp] = data_list
display(Table1)
| Features | lithium | olanzapine | |
|---|---|---|---|
| 0 | Total, N | 19106 | 12412 |
| 1 | Age at diagnosis, median (IQR) | 40.82 (24.13) | 39.08 (22.36) |
| 2 | Age at medication initiation, median (IQR) | 46.51 (23.20) | 42.47 (23.04) |
| 3 | Years between diagnosis and exposure, median (... | 7.37 (15.60) | 5.45 (13.35) |
| 4 | Female, n (%) | 11526 (60.33 %) | 6858 (55.25 %) |
| 5 | First presentation mania, n (%) | 4705 (24.63 %) | 3498 (28.18 %) |
| 6 | First presentation depression, n (%) | 11233 (58.79 %) | 7453 (60.05 %) |
| 7 | Depression dominant, n (%) | 9662 (50.57 %) | 6457 (52.02 %) |
| 8 | Psychotic experiences, n (%) | 4939 (25.85 %) | 3806 (30.66 %) |
| 9 | Self-harm history, n (%) | 2362 (12.36 %) | 1822 (14.68 %) |
| 10 | Smoker, n (%) | 13970 (73.12 %) | 8796 (70.87 %) |
| 11 | Family history for bipolar disorder, n (%) | 402 (2.10 %) | 136 (1.10 %) |
| 12 | Family history for depression, n (%) | 375 (1.96 %) | 245 (1.97 %) |
| 13 | Family history for psychosis, n (%) | 111 (0.58 %) | 141 (1.14 %) |
| 14 | Overweight or obese, n (%) | 6540 (34.23 %) | 5038 (40.59 %) |
Some values are negative, which should be impossible: ages and time intervals must be non-negative. We inspect the offending rows before correcting them.
# Show the rows with impossible (negative) values before we fix them.
negative_gap = df['symptom_to_exposure'] < 0
display(df.loc[negative_gap, 'symptom_to_exposure'])
negative_age = df['age_first_diagnosis'] < 0
display(df.loc[negative_age, ['age_first_diagnosis', 'symptom_to_exposure']])
pd.set_option('display.max_rows', 100)
# Also inspect suspiciously long diagnosis-to-exposure gaps (65-80 years).
long_gap = (df['symptom_to_exposure'] > 65) & (df['symptom_to_exposure'] < 80)
display(df.loc[long_gap])
74 -8.372348
79 -4.016427
103 -2.948665
158 -7.860370
176 -0.032854
...
38872 -5.037645
38877 -0.016427
38880 -5.653662
38924 -0.977413
38937 -3.529090
Name: symptom_to_exposure, Length: 1213, dtype: float64
| age_first_diagnosis | symptom_to_exposure | |
|---|---|---|
| 2067 | -77.500343 | 106.620120 |
| 6260 | -68.498291 | 104.038330 |
| 6771 | -0.134155 | 24.533880 |
| 7937 | -87.498970 | 134.294310 |
| 8706 | -98.499657 | 143.857640 |
| 8709 | -109.500340 | 141.295000 |
| 9641 | -46.499657 | 89.848053 |
| 14944 | -58.502396 | 102.907600 |
| 15065 | -49.503078 | 104.700890 |
| 15412 | -90.499657 | 145.867220 |
| 15525 | -117.500340 | 148.963730 |
| 15585 | -99.498970 | 134.179340 |
| 15849 | -56.498287 | 99.531830 |
| 15997 | -73.500343 | 139.493500 |
| 16216 | -0.314853 | 23.898699 |
| 17417 | -96.498291 | 139.028060 |
| 23681 | -0.114990 | 57.355236 |
| 24757 | -22.499659 | 97.952087 |
| 24810 | -32.498287 | 97.932922 |
| 24849 | -35.498974 | 97.092400 |
| 25084 | -0.432580 | 31.414101 |
| 25134 | -30.499659 | 98.453117 |
| 31628 | -50.502396 | 102.009580 |
| 31843 | -29.500341 | 95.737167 |
| 34554 | -60.498287 | 136.251880 |
| 37134 | -0.156057 | 60.238194 |
| patid | pracid | diagnosis_date | sex | yob | first_reg_date | transfer_out_date | death_date | cohort_start | cohort_end | end_reason | exposure | incident_script | dob | exposure_end | exposure_start | suitable | responder2 | response2_1 | adhd_date | adhd | alcohol_date | alcohol | asthma_date | asthma | cannabis_date | cannabis | conduct_date | conduct | dermatitis_date | dermatitis | migraine_date | migraine | other_substance_misuse_date | other_substance_misuse | psychosis_date | psychosis | self_harm_date | self_harm | stress_date | stress | mania_date | mania | mania_type | N_man_b4 | depression_date | depression | N_dep_b4 | symptom_to_exposure | symptom_to_diagnosis | dominant | FH_BPD_date | FH_BPD | FH_psychosis_date | FH_psychosis | FH_depression_date | FH_depression | FH_NOS_date | FH_NOS | FH_anxiety_date | FH_anxiety | FH_suicide_date | FH_suicide | FH_LD_date | FH_LD | FH_substance_date | FH_substance | FH_any | anxiety_date | anxiety | PD_date | PD | sleep_date | sleep | T2DM_date | T2DM | BMI_date | BMI | weight | ethnicity_date | year_exposure | ex_time | smoke_date | CHD_date | CHD | relationship | relationship_date | diastolic | BP_date | systolic | hypertension | eGFR_date | CKD3 | LDL | LDL_date | hi_LDL | HDL | HDL_date | lo_HDL | TSH | TSH_date | thyroid_blood | hypothyroid_date | hypothyroid | hypothyroid_combined | ca | ca_date | hi_ca | lo_ca | source | first_episode | OCD_date | OCD | psych_FH_date | first_date | age_first_exposure | age_first_diagnosis | hyperthyroid_date | hyperthyroid | smoker | cardiac_arrythmia_date | cardiac_arrythmia | Neurological_disorders_date | Neurological_disorders | Liver_disease_date | Liver_disease | HIV_AIDS_date | HIV_AIDS | Fluid_electrolyte_disorder_date | Fluid_electrolyte_disorders | Diabetes_uncomplicated_date | Diabetes_uncomplicated | Diabetes_organ_damage_date | Diabetes_organ_damage | Deficiency_anaemia_date | Deficiency_anaemia | Congestive_heart_failure_date | Congestive_heart_failure | 
Coagulopathy_date | Coagulopathy | Chronic_pulmonary_disease_date | Chronic_pulmonary_disease | Weight_loss_date | Weight_loss | Valvular_disease_date | Valvular_disease | RA_date | RA | Pulmonary_circulation_date | Peripheral_vascular_date | Peripheral_vascular | Peptic_ulcer_date | Peptic_ulcer | first_AP_date | first_MS_date | first_li_date | first_olan_date | ap_b4 | ap_duration | ms_b4 | ms_duration | li_b4 | olan_b4 | SSRI | first_SSRI_date | last_SSRI_date | SSRI_b4 | SSRI_during | TCA | first_TCAs_date | last_TCAs_date | TCA_b4 | TCA_during | other_ADs | first_other_ADs_date | last_other_ADs_date | other_AD_b4 | other_AD_during | any_AD_b4 | any_AD_during | Pulmonary_circulation | ethnicity | age_diagnosis | age_first_reg | age_transfer_out | age_death | age_adhd | age_alcohol | age_asthma | age_cannabis | age_conduct | age_dermatitis | age_migraine | age_other_substance_misuse | age_psychosis | age_self_harm | age_stress | age_mania | age_depression | age_FH_BPD | age_FH_psychosis | age_FH_depression | age_FH_NOS | age_FH_anxiety | age_FH_suicide | age_FH_LD | age_FH_substance | age_anxiety | age_PD | age_sleep | age_T2DM | age_BMI | age_ethnicity | age_smoke | age_CHD | age_relationship | age_BP | age_eGFR | age_LDL | age_HDL | age_TSH | age_hypothyroid | age_ca | age_OCD | age_psych_FH | age_first | age_hyperthyroid | age_cardiac_arrythmia | age_Neurological_disorders | age_Liver_disease | age_HIV_AIDS | age_Fluid_electrolyte_disorder | age_Diabetes_uncomplicated | age_Diabetes_organ_damage | age_Deficiency_anaemia | age_Congestive_heart_failure | age_Coagulopathy | age_Chronic_pulmonary_disease | age_Weight_loss | age_Valvular_disease | age_RA | age_Pulmonary_circulation | age_Peripheral_vascular | age_Peptic_ulcer | age_first_AP | age_first_MS | age_first_li | age_first_olan | age_first_SSRI | age_last_SSRI | age_first_TCAs | age_last_TCAs | age_first_other_ADs | age_last_other_ADs | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 5883 | 13787565 | 565 | 2001-11-14 | male | 1933-01-01 | 2001-02-05 | NaN | NaN | 05feb2001 | 19nov2014 | end f/u | lithium | 0 | 03jul1933 | 16sep2003 | 30aug2002 | 1 | 0.0 | NaN | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | 2004-09-24 | 0 | NaN | 0 | NaN | 0 | 2001-11-14 | 1 | 1993-10-15 | 1 | NaN | 0 | 2001-11-14 | 1 | mania+psychoses | 1 | 1933-01-28 | 1 | 10 | 69.585213 | 68.793976 | depression | NaN | 0 | NaN | 0 | NaN | 0 | NaN | NaN | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | 0 | 1933-01-28 | 1 | 1983-04-06 | 1.0 | 2001-09-22 | 1 | NaN | 0 | 2001-02-09 | 22.737589 | healthy weight | NaN | NaN | 1.045859 | 2003-10-22 | 2008-11-20 | 0 | 0 | NaN | 82.0 | 2002-05-22 | 150.0 | 1 | 2010-04-16 | 0 | NaN | NaN | 0 | NaN | NaN | 0.0 | 0.57 | 2001-11-05 | NaN | NaN | 0 | NaN | NaN | NaN | 0 | 0 | GOLD | depression | NaN | 0 | NaN | 1933-01-28 | 69.158112 | 68.366875 | NaN | 0 | never smoker | NaN | 0 | 2006-02-27 | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | NaN | 0 | NaN | 0 | 2001-09-28 | NaN | 2001-11-16 | 2003-09-16 | 1 | 0.919918 | 0 | NaN | 1 | 0 | Fluoxetine hydrochloride | 2004-10-27 | 2007-08-06 | 0 | 0 | NaN | NaN | NaN | 0 | 0 | Mirtazapine | 2001-02-16 | 2004-10-21 | 0 | 1 | 0.0 | 1.0 | 0 | White | 68.915068 | 68.142466 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 71.778082 | NaN | NaN | 68.915068 | 60.827397 | NaN | 68.915068 | 0.073973 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 0.073973 | 50.293151 | 68.769863 | NaN | 68.153425 | NaN | 70.852055 | 75.936986 | NaN | 69.432877 | 77.339726 | NaN | NaN | 68.890411 | NaN | NaN | NaN | NaN | 0.073973 | NaN | NaN | 73.205479 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 68.786301 | NaN | 68.920548 | 70.753425 | 71.868493 | 74.643836 | NaN | NaN | 68.172603 | 71.852055 |
| 25718 | 2842029 | 29 | 1965-01-01 | female | 1911-01-01 | 2002-01-29 | 2004-03-16 | 2004-02-14 | 29jan2002 | 16mar2004 | died | lithium | 0 | 03jul1911 | 09mar2004 | 14feb2002 | 1 | 1.0 | 1.0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | 1965-01-01 | 1 | NaN | 0 | NaN | 0 | NaN | 0 | unclear | 0 | 1930-01-01 | 1 | 2 | 72.120468 | 35.000683 | depression | NaN | 0 | NaN | 0 | NaN | 0 | NaN | NaN | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | 0 | NaN | 0 | NaN | NaN | NaN | 0 | NaN | 0 | NaN | NaN | healthy weight | NaN | NaN | 2.064339 | NaN | NaN | 0 | 0 | NaN | 80.0 | 2003-04-19 | 128.0 | 0 | NaN | 0 | NaN | NaN | 0 | NaN | NaN | 0.0 | NaN | NaN | NaN | NaN | 0 | NaN | NaN | NaN | 0 | 0 | GOLD | depression | NaN | 0 | NaN | 1930-01-01 | 90.620125 | 53.500343 | NaN | 0 | never smoker | NaN | 0 | 1998-01-01 | 1 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | NaN | 0 | NaN | 0 | NaN | NaN | 2002-02-14 | NaN | 0 | NaN | 0 | NaN | 0 | 0 | Sertraline hydrochloride | 2002-02-14 | 2003-02-06 | 0 | 1 | NaN | NaN | NaN | 0 | 0 | Mirtazapine | 2003-08-19 | 2003-11-28 | 0 | 0 | 0.0 | 1.0 | 0 | White | 54.038356 | 91.139726 | 93.268493 | 93.183562 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 54.038356 | NaN | NaN | NaN | 19.013699 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 92.358904 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 19.013699 | NaN | NaN | 87.060274 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 91.183562 | NaN | 91.183562 | 92.161644 | NaN | NaN | 92.693151 | 92.969863 |
| 30300 | 4704742 | 742 | 1928-01-01 | female | 1917-01-01 | 1980-10-29 | 2006-12-01 | 2006-12-01 | 01jan1987 | 01dec2006 | died | lithium | 1 | 03jul1917 | 05aug1997 | 23oct1996 | 1 | 0.0 | 0.0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | 1976-06-01 | 1 | 1981-06-06 | 1 | NaN | 0 | 1976-06-01 | 1 | mania+psychoses | 1 | NaN | 0 | 0 | 68.810402 | 0.000000 | mania | NaN | 0 | NaN | 0 | NaN | 0 | NaN | NaN | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | 0 | NaN | 0 | NaN | NaN | 1993-06-01 | 1 | NaN | 0 | NaN | NaN | healthy weight | NaN | NaN | 0.783025 | 2004-11-11 | 2002-12-10 | 0 | 1 | NaN | 80.0 | 1999-12-07 | 140.0 | 0 | 2003-08-21 | 0 | NaN | NaN | 0 | NaN | NaN | 0.0 | NaN | NaN | NaN | 1974-01-01 | 1 | NaN | NaN | NaN | 0 | 0 | GOLD | mania | NaN | 0 | NaN | 1928-01-01 | 79.307327 | 10.496920 | NaN | 0 | ex smoker | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | 2002-12-10 | 0 | NaN | 0 | NaN | 0 | NaN | 0 | 1997-09-01 | 0 | NaN | 0 | NaN | NaN | 0 | NaN | 0 | NaN | NaN | 1996-10-23 | NaN | 0 | NaN | 0 | NaN | 0 | 0 | Sertraline hydrochloride | 1997-11-03 | 2001-12-10 | 0 | 0 | Clomipramine hydrochloride | 2002-05-24 | 2002-07-15 | 0 | 0 | Venlafaxine hydrochloride | 1996-10-15 | 2003-06-27 | 0 | 1 | 0.0 | 1.0 | 0 | White | 11.005479 | 63.868493 | 89.975342 | 89.975342 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 59.454795 | 64.471233 | NaN | 59.454795 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 76.465753 | NaN | NaN | NaN | 87.920548 | 85.997260 | NaN | 82.986301 | 86.693151 | NaN | NaN | NaN | 57.038356 | NaN | NaN | NaN | 11.005479 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 85.99726 | NaN | NaN | NaN | 80.720548 | NaN | NaN | NaN | NaN | NaN | NaN | 79.863014 | NaN | 80.893151 | 84.997260 | 85.449315 | 85.591781 | 79.841096 | 86.542466 |
| 32956 | 6177329 | 329 | 1936-11-24 | male | 1936-01-01 | 1991-01-28 | 2009-03-12 | NaN | 28jan1991 | 12mar2009 | started AP | lithium | 1 | 02jul1936 | 14may2004 | 16apr2004 | 1 | 0.0 | 0.0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | 1936-11-24 | 1 | NaN | 0 | NaN | 0 | NaN | 0 | unclear | 0 | 2006-01-23 | 0 | 0 | 67.392197 | 0.000000 | unclear | NaN | 0 | NaN | 0 | NaN | 0 | NaN | NaN | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | 0 | NaN | 0 | NaN | NaN | NaN | 0 | NaN | 0 | 2002-02-28 | 31.134583 | obese | NaN | NaN | 0.076660 | 2004-03-12 | NaN | 0 | 0 | NaN | 90.0 | 2004-03-12 | 156.0 | 0 | 2009-01-15 | 0 | NaN | NaN | 0 | NaN | NaN | 0.0 | NaN | NaN | NaN | NaN | 0 | NaN | NaN | NaN | 0 | 0 | GOLD | mania | NaN | 0 | NaN | 1936-11-24 | 67.789185 | 0.396988 | NaN | 0 | never smoker | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | NaN | 0 | NaN | 0 | 2004-05-14 | NaN | 2004-04-16 | NaN | 0 | 0.284736 | 0 | NaN | 0 | 0 | NaN | NaN | NaN | 0 | 0 | NaN | NaN | NaN | 0 | 0 | NaN | NaN | NaN | 0 | 0 | 0.0 | 0.0 | 0 | White | 0.898630 | 55.112329 | 73.243836 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 0.898630 | NaN | NaN | NaN | 70.109589 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 66.205479 | NaN | 68.241096 | NaN | NaN | 68.241096 | 73.090411 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 0.898630 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 68.413699 | NaN | 68.336986 | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 36808 | 8618764 | 764 | 1931-07-14 | female | 1931-01-01 | 1984-08-01 | 2003-11-21 | 2003-11-21 | 01jan1987 | 21nov2003 | started AP | lithium | 1 | 03jul1931 | 23aug1999 | 07apr1999 | 1 | 0.0 | 0.0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | 1931-07-14 | 1 | mania | 1 | NaN | 0 | 0 | 67.731689 | 0.000000 | mania | NaN | 0 | NaN | 0 | NaN | 0 | NaN | NaN | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | 0 | NaN | 0 | NaN | NaN | NaN | 0 | NaN | 0 | NaN | NaN | healthy weight | NaN | NaN | 0.377823 | 1996-09-09 | NaN | 0 | 0 | NaN | 70.0 | 1995-11-07 | 120.0 | 0 | NaN | 0 | NaN | NaN | 0 | NaN | NaN | 0.0 | NaN | NaN | NaN | 1931-07-14 | 1 | NaN | NaN | NaN | 0 | 0 | GOLD | mania | NaN | 0 | NaN | 1931-07-14 | 67.761810 | 0.030116 | NaN | 0 | current smoker | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | 0 | NaN | NaN | 0 | NaN | 0 | 1999-08-23 | NaN | 1999-04-07 | NaN | 0 | 4.219028 | 0 | NaN | 0 | 0 | NaN | NaN | NaN | 0 | 0 | NaN | NaN | NaN | 0 | 0 | NaN | NaN | NaN | 0 | 0 | 0.0 | 0.0 | 0 | White | 0.531507 | 53.619178 | 72.936986 | 72.936986 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 0.531507 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 65.734247 | NaN | NaN | 64.893151 | NaN | NaN | NaN | NaN | 0.531507 | NaN | NaN | NaN | 0.531507 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 68.687671 | NaN | 68.309589 | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
We correct those invalid values: negative symptom-to-exposure gaps are clipped to 0, and implausible values (negative ages, gaps of 80+ years) are set to missing.
# Clamp negative onset-to-exposure gaps to zero; blank out impossible
# values (negative ages, gaps of 80+ years).
df.loc[df['symptom_to_exposure'] < 0, 'symptom_to_exposure'] = 0
df.loc[df['age_first_diagnosis'] < 0, 'age_first_diagnosis'] = np.nan
df.loc[df['symptom_to_exposure'] >= 80, 'symptom_to_exposure'] = np.nan

# Multiclass target: 0 = non-responder, 1 = missing response, 2 = responder.
target_multiclass = df.response2_1.replace(1, 2).fillna(1)
df['target_multiclass'] = target_multiclass
# Binary target: lithium responder at two years.
target_lithium2y = (df.exposure == 'lithium') & (df.response2_1 == 1)
df['target_lithium2y'] = target_lithium2y.astype(int)
# Binary target: which drug the patient was exposed to.
target_exposure = (df.exposure == 'olanzapine')
df['target_exposure'] = target_exposure.astype(int)
# Short aliases for each target column.
targets = dict(
    resp='response2_1',
    exp='target_exposure',  # 0=lithium 1=olanzapine
    multi='target_multiclass',
    lithium2y='target_lithium2y',
)
# Show the class balance of every target.
for alias in targets:
    print(df[targets[alias]].value_counts())
    print()
0.0 14785 1.0 11848 Name: response2_1, dtype: int64 0 19106 1 12412 Name: target_exposure, dtype: int64 0.0 14785 2.0 11848 1.0 4885 Name: target_multiclass, dtype: int64 0 23501 1 8017 Name: target_lithium2y, dtype: int64
We have a lot of "age" features. The ones that have few missing values are good candidates: let's keep only the ages with less than 5000 missing values.
# Rank the age_* columns by how many values they are missing.
age_columns = [col for col in df.columns if col.startswith('age_')]
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    age_df = df[age_columns].isna().sum().sort_values(ascending=True)
    display(age_df)
# Keep only age features with fewer than 5000 missing values.
features_age = age_df[age_df < 5000].index.to_list()
age_first_exposure 0 age_first_reg 0 age_diagnosis 21 age_first_diagnosis 26 age_first 143 age_smoke 968 age_BP 3102 age_eGFR 5370 age_TSH 9770 age_BMI 10303 age_first_li 10466 age_first_AP 11000 age_depression 11533 age_ethnicity 12565 age_last_SSRI 12704 age_first_SSRI 12722 age_HDL 14024 age_first_olan 14456 age_transfer_out 15942 age_LDL 16581 age_ca 16668 age_anxiety 17770 age_first_MS 17786 age_last_TCAs 18235 age_first_TCAs 18258 age_psychosis 19271 age_last_other_ADs 19447 age_first_other_ADs 19451 age_sleep 21187 age_mania 21830 age_Chronic_pulmonary_disease 23645 age_self_harm 24530 age_dermatitis 24955 age_Diabetes_uncomplicated 25385 age_stress 25915 age_asthma 25949 age_Neurological_disorders 26065 age_T2DM 26324 age_alcohol 26479 age_death 26786 age_hypothyroid 27231 age_PD 28117 age_CHD 28361 age_relationship 28484 age_migraine 28782 age_other_substance_misuse 28977 age_Deficiency_anaemia 29025 age_cardiac_arrythmia 29594 age_Weight_loss 29594 age_Fluid_electrolyte_disorder 29903 age_Diabetes_organ_damage 30053 age_RA 30419 age_Congestive_heart_failure 30524 age_OCD 30601 age_Peptic_ulcer 30629 age_Peripheral_vascular 30642 age_cannabis 30764 age_Liver_disease 30781 age_Valvular_disease 30950 age_psych_FH 31014 age_FH_NOS 31052 age_Pulmonary_circulation 31056 age_hyperthyroid 31077 age_FH_depression 31108 age_Coagulopathy 31208 age_adhd 31217 age_FH_BPD 31234 age_FH_psychosis 31267 age_conduct 31275 age_HIV_AIDS 31289 age_FH_suicide 31383 age_FH_substance 31398 age_FH_anxiety 31492 age_FH_LD 31509 dtype: int64
We now remove features that do not make sense from a clinical perspective, which could confuse the learning process.
# Drop age features that are clinically redundant or uninformative.
for clinically_excluded in ('age_BP', 'age_first', 'age_diagnosis', 'age_smoke'):
    features_age.remove(clinically_excluded)
print('age features:', features_age)
age features: ['age_first_exposure', 'age_first_reg', 'age_first_diagnosis']
We now select a good list of features, agnostic as well as informed
# Agnostic feature set, chosen on data availability rather than clinical
# judgement (to be complemented with the age features).
# Kept as a newline-joined string: downstream code splits it on '\n'.
features_agnostic = '\n'.join([
    'adhd',
    'FH_suicide',
    'mania',
    'psychosis',
    'relationship',
    'self_harm',
    'sex',
    'sleep',
    'smoker',
    'T2DM',
    'OCD',
    'migraine',
    'hypothyroid',
    'CHD',
    'other_substance_misuse',
    'cannabis',
    'alcohol',
    'depression',
    'FH_anxiety',
    'FH_any',
    'FH_BPD',
    'FH_depression',
    'FH_LD',
    'FH_psychosis',
    'N_dep_b4',
    'N_man_b4',
    'anxiety',
    'stress',
    'hi_LDL',
    'lo_HDL',
    'weight',
    'CKD3',
    'symptom_to_exposure',
    'dominant',
])
# Informed feature set, selected on clinical grounds.
# Kept as a newline-joined string: downstream code splits it on '\n'.
features_informed = '\n'.join([
    'age_first_exposure',
    'age_first_diagnosis',
    'symptom_to_exposure',
    'psychosis',
    'depression',
    'mania',
    'dominant',
    'sex',
    'FH_BPD',
    'FH_depression',
    'FH_psychosis',
    'weight',
    'self_harm',
    'cannabis',
    'anxiety',
    'stress',
    'sleep',
    'other_substance_misuse',
    'relationship',
    'OCD',
    'adhd',
    'smoker',
    'alcohol',
    'FH_suicide',
    'hi_LDL',
    'lo_HDL',
    'CKD3',
    'T2DM',
    'migraine',
    'hypothyroid',
    'CHD',
    'FH_anxiety',
    'FH_any',
    'FH_LD',
])
# Features whose recording quality is questionable ("shaky"); they are
# removed from the informed set to build the 13-feature core set.
shaky_features = [
    'cannabis',
    'anxiety',
    'stress',
    'sleep',
    'other_substance_misuse',
    'relationship',
    'OCD',
    'adhd',
    'smoker',
    'alcohol',
    'FH_suicide',
    'hi_LDL',
    'lo_HDL',
    'CKD3',
    'T2DM',
    'migraine',
    'hypothyroid',
    'CHD',
    'FH_anxiety',
    'FH_any',
    'FH_LD',
]
# Assemble the three candidate feature sets, keyed by their size.
features = {
    # informed (clinician-selected)
    '34': [name.strip() for name in features_informed.split('\n')],
    # agnostic (availability-driven), plus the retained age features
    '37': [name.strip() for name in features_agnostic.split('\n')] + features_age,
}
# important: informed features minus the shaky ones
features['13'] = [name for name in features['34'] if name not in shaky_features]

print(len(features['34']), 'informed features:', features['34'])
print(len(features['37']), 'agnostic features:', features['37'])
print(len(features['13']), 'important features:', features['13'])
print(len(shaky_features), 'shaky features:', shaky_features)
34 informed features: ['age_first_exposure', 'age_first_diagnosis', 'symptom_to_exposure', 'psychosis', 'depression', 'mania', 'dominant', 'sex', 'FH_BPD', 'FH_depression', 'FH_psychosis', 'weight', 'self_harm', 'cannabis', 'anxiety', 'stress', 'sleep', 'other_substance_misuse', 'relationship', 'OCD', 'adhd', 'smoker', 'alcohol', 'FH_suicide', 'hi_LDL', 'lo_HDL', 'CKD3', 'T2DM', 'migraine', 'hypothyroid', 'CHD', 'FH_anxiety', 'FH_any', 'FH_LD'] 37 agnostic features: ['adhd', 'FH_suicide', 'mania', 'psychosis', 'relationship', 'self_harm', 'sex', 'sleep', 'smoker', 'T2DM', 'OCD', 'migraine', 'hypothyroid', 'CHD', 'other_substance_misuse', 'cannabis', 'alcohol', 'depression', 'FH_anxiety', 'FH_any', 'FH_BPD', 'FH_depression', 'FH_LD', 'FH_psychosis', 'N_dep_b4', 'N_man_b4', 'anxiety', 'stress', 'hi_LDL', 'lo_HDL', 'weight', 'CKD3', 'symptom_to_exposure', 'dominant', 'age_first_exposure', 'age_first_reg', 'age_first_diagnosis'] 13 important features: ['age_first_exposure', 'age_first_diagnosis', 'symptom_to_exposure', 'psychosis', 'depression', 'mania', 'dominant', 'sex', 'FH_BPD', 'FH_depression', 'FH_psychosis', 'weight', 'self_harm'] 21 shaky features: ['cannabis', 'anxiety', 'stress', 'sleep', 'other_substance_misuse', 'relationship', 'OCD', 'adhd', 'smoker', 'alcohol', 'FH_suicide', 'hi_LDL', 'lo_HDL', 'CKD3', 'T2DM', 'migraine', 'hypothyroid', 'CHD', 'FH_anxiety', 'FH_any', 'FH_LD']
# Sanity-check the overlap between the two main feature sets.
agnostic_only = set(features['37']) - set(features['34'])
informed_only = set(features['34']) - set(features['37'])
print("Agnostic features that are not informed features:", list(agnostic_only))
print("Informed features that are not agnostic features:", list(informed_only))
Agnostic features that are not informed features: ['N_man_b4', 'N_dep_b4', 'age_first_reg'] Informed features that are not agnostic features: []
Here we prepare the X dataframe (samples with their features) and the y dataframe (target to predict, for each sample). We also fix the types of a few features to make sure the algorithms will interpret them correctly.
from sklearn import preprocessing, metrics
def prepare(features, target, data=None):
    """Build the model inputs X and the prediction target y.

    Parameters
    ----------
    features : list of str
        Column names to use as model features.
    target : str
        Column name of the label to predict.
    data : pandas.DataFrame, optional
        Source table; defaults to the notebook-level ``df``.

    Returns
    -------
    (pandas.DataFrame, pandas.Series)
        Feature matrix (string columns label-encoded to int64) and the
        target column, restricted to rows with no missing value in
        `features` + [`target`].
    """
    source = df if data is None else data
    # First remove the samples with N/A for any of the features or the label.
    df_withoutNA = source.dropna(subset=(features + [target]))
    df_features = df_withoutNA[features]
    # Encode all string values in the features into an int value
    # (we might use One Hot encoding later?).
    # Bug fix: the dtype lookup previously consulted the global df's dtypes
    # rather than df_features' own; use the selected frame consistently.
    to_encode = [key for key, dtype in df_features.dtypes.items()
                 if dtype not in ['float64', 'int64']]
    new_df_features = df_features.copy()
    if to_encode:
        le = preprocessing.LabelEncoder()
        new_df_features.update(df_features[to_encode].apply(le.fit_transform))
        new_df_features[to_encode] = new_df_features[to_encode].astype(np.int64)
    # Because lo_HDL has missing values, it was interpreted by pandas as
    # float. Now that the missing values are removed, it can be an int.
    if 'lo_HDL' in features:
        new_df_features.lo_HDL = new_df_features.lo_HDL.astype(np.int64)
    # Now we have our X and y.
    return new_df_features, df_withoutNA[target]
X = dict()
y = dict()
# One (X, y) pair per feature-set / target combination. 'exposure' is kept
# as an extra column so downstream cells can split patients on it.
for feat_set in features:
    for tgt in targets:
        name = feat_set + '_' + tgt
        X[name], y[name] = prepare(features[feat_set] + ['exposure'], targets[tgt])
        if tgt != 'exp':
            # Cast the label to int for the non-exposure targets.
            y[name] = y[name].astype(np.int64)
        print(name, len(X[name]), 'rows')
34_resp 26509 rows 34_exp 31369 rows 34_multi 31369 rows 34_lithium2y 31369 rows 37_resp 26509 rows 37_exp 31369 rows 37_multi 31369 rows 37_lithium2y 31369 rows 13_resp 26510 rows 13_exp 31370 rows 13_multi 31370 rows 13_lithium2y 31370 rows
from scipy.stats import pearsonr, spearmanr
import seaborn as sns
# X_dict differs from X in that 'exposure' is dropped: we do not want it as a
# feature, but X still needs it so we know the exposure of each patient.
X_dict = dict()
y_dict = dict()
for key in X:
    without_exposure = X[key].drop('exposure', axis=1)
    X_dict[key] = without_exposure
    # Split the columns by type: numerical, categorical, and the binary
    # subset of the categorical ones.
    X_dict['num_' + key] = without_exposure.loc[:, without_exposure.dtypes == np.float64]
    X_dict['cat_' + key] = without_exposure.loc[:, without_exposure.dtypes == np.int64]
    categorical = X_dict['cat_' + key]
    X_dict['bin_' + key] = categorical.loc[:, categorical.nunique() == 2]
    for prefix in ('', 'num_', 'cat_', 'bin_'):
        y_dict[prefix + key] = y[key]
    # When predicting exposure or lithium2y it makes no sense to also
    # predict within a single exposure, so lit_/ola_ splits are only
    # built for the other targets.
    if '_exp' not in key and '_lithium2y' not in key:
        lit_mask = X[key].exposure == 0
        ola_mask = X[key].exposure == 1
        X_dict['lit_' + key] = X[key].loc[lit_mask].drop('exposure', axis=1)
        X_dict['ola_' + key] = X[key].loc[ola_mask].drop('exposure', axis=1)
        y_dict['lit_' + key] = y[key][X[key].loc[lit_mask].index]
        y_dict['ola_' + key] = y[key][X[key].loc[ola_mask].index]
print(X_dict.keys())
print(y_dict.keys())
dict_keys(['34_resp', 'num_34_resp', 'cat_34_resp', 'bin_34_resp', 'lit_34_resp', 'ola_34_resp', '34_exp', 'num_34_exp', 'cat_34_exp', 'bin_34_exp', '34_multi', 'num_34_multi', 'cat_34_multi', 'bin_34_multi', 'lit_34_multi', 'ola_34_multi', '34_lithium2y', 'num_34_lithium2y', 'cat_34_lithium2y', 'bin_34_lithium2y', '37_resp', 'num_37_resp', 'cat_37_resp', 'bin_37_resp', 'lit_37_resp', 'ola_37_resp', '37_exp', 'num_37_exp', 'cat_37_exp', 'bin_37_exp', '37_multi', 'num_37_multi', 'cat_37_multi', 'bin_37_multi', 'lit_37_multi', 'ola_37_multi', '37_lithium2y', 'num_37_lithium2y', 'cat_37_lithium2y', 'bin_37_lithium2y', '13_resp', 'num_13_resp', 'cat_13_resp', 'bin_13_resp', 'lit_13_resp', 'ola_13_resp', '13_exp', 'num_13_exp', 'cat_13_exp', 'bin_13_exp', '13_multi', 'num_13_multi', 'cat_13_multi', 'bin_13_multi', 'lit_13_multi', 'ola_13_multi', '13_lithium2y', 'num_13_lithium2y', 'cat_13_lithium2y', 'bin_13_lithium2y']) dict_keys(['34_resp', 'num_34_resp', 'cat_34_resp', 'bin_34_resp', 'lit_34_resp', 'ola_34_resp', '34_exp', 'num_34_exp', 'cat_34_exp', 'bin_34_exp', '34_multi', 'num_34_multi', 'cat_34_multi', 'bin_34_multi', 'lit_34_multi', 'ola_34_multi', '34_lithium2y', 'num_34_lithium2y', 'cat_34_lithium2y', 'bin_34_lithium2y', '37_resp', 'num_37_resp', 'cat_37_resp', 'bin_37_resp', 'lit_37_resp', 'ola_37_resp', '37_exp', 'num_37_exp', 'cat_37_exp', 'bin_37_exp', '37_multi', 'num_37_multi', 'cat_37_multi', 'bin_37_multi', 'lit_37_multi', 'ola_37_multi', '37_lithium2y', 'num_37_lithium2y', 'cat_37_lithium2y', 'bin_37_lithium2y', '13_resp', 'num_13_resp', 'cat_13_resp', 'bin_13_resp', 'lit_13_resp', 'ola_13_resp', '13_exp', 'num_13_exp', 'cat_13_exp', 'bin_13_exp', '13_multi', 'num_13_multi', 'cat_13_multi', 'bin_13_multi', 'lit_13_multi', 'ola_13_multi', '13_lithium2y', 'num_13_lithium2y', 'cat_13_lithium2y', 'bin_13_lithium2y'])
display(X_dict['cat_13_lithium2y'])
| psychosis | depression | mania | dominant | sex | FH_BPD | FH_depression | FH_psychosis | weight | self_harm | |
|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 1 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 0 | 0 |
| 2 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 |
| 3 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 2 | 0 |
| 4 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 |
| 5 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 38950 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 0 | 1 |
| 38951 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 38953 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 0 |
| 38954 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 38956 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
31370 rows × 10 columns
%matplotlib inline
import seaborn as sns
sns.set_theme(style="white")
sns.set(font_scale=2)
from matplotlib.ticker import FormatStrFormatter
# Distributions of age at first exposure (left) and at first diagnosis (right).
values = {0: df.age_first_exposure,
          1: df.age_first_diagnosis}
fig, ax = plt.subplots(ncols=2, nrows=1, sharey=True)
plt.subplots_adjust(hspace=0.6)
fig.set_size_inches(10, 6)
fig.tight_layout()
for i in values.keys():
    # seaborn.distplot is deprecated (it raised the FutureWarning shown
    # below); histplot is the drop-in replacement for a counts histogram.
    sns.histplot(values[i], kde=False, ax=ax[i], bins=range(100))
    # ax[i].set_yscale('log')
    ax[i].yaxis.set_major_formatter(FormatStrFormatter('%.0f'))
    ax[i].set_xlim((0, 100))
display(df.age_first_diagnosis)
/Users/fehmi/GoogleDrive/sics/projects/ucl/lithium/venv/lib/python3.9/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning)
1 80.722794
2 41.752224
3 56.533882
4 21.207392
5 82.948662
...
38950 55.764545
38951 46.135525
38953 28.867899
38954 31.498974
38956 38.529774
Name: age_first_diagnosis, Length: 31518, dtype: float64
Now let's check the number of missing values for each feature, and look at their entropy
import scipy.stats as st
def entropy3(labels, base=None):
    """Shannon entropy of the empirical distribution of `labels`.

    Parameters
    ----------
    labels : sequence
        Observed values; their frequency counts define the distribution.
    base : float, optional
        Logarithm base. Defaults to the natural logarithm (nats),
        matching scipy's default and this notebook's existing output.

    Returns
    -------
    float
    """
    vc = pd.Series(labels).value_counts(normalize=False, sort=False)
    # Bug fix: `base` used to be accepted but never forwarded, so callers
    # always got natural-log entropy regardless of the argument. Passing
    # base=None through keeps that default while making the parameter work.
    return st.entropy(vc, base=base)
# Missing-value counts and entropy for every agnostic ('37') feature.
nan_counts = df[features['37']].isna().sum()
entropies = df[features['37']].apply(entropy3, axis=0)
nan_entropy = pd.concat([nan_counts, entropies], axis=1)
nan_entropy.columns = ['N/A', 'Entropy']
display(nan_entropy)
display(df[features['37']])
| N/A | Entropy | |
|---|---|---|
| adhd | 0 | 0.033217 |
| FH_suicide | 0 | 0.027631 |
| mania | 0 | 0.573333 |
| psychosis | 0 | 0.590539 |
| relationship | 0 | 0.390328 |
| self_harm | 0 | 0.391580 |
| sex | 0 | 0.679209 |
| sleep | 0 | 0.379125 |
| smoker | 0 | 1.066039 |
| T2DM | 0 | 0.175462 |
| OCD | 0 | 0.096136 |
| migraine | 0 | 0.203815 |
| hypothyroid | 0 | 0.201039 |
| CHD | 0 | 0.145867 |
| other_substance_misuse | 0 | 0.191553 |
| cannabis | 0 | 0.078576 |
| alcohol | 0 | 0.224948 |
| depression | 0 | 0.675798 |
| FH_anxiety | 0 | 0.010781 |
| FH_any | 0 | 0.233201 |
| FH_BPD | 0 | 0.086404 |
| FH_depression | 0 | 0.096757 |
| FH_LD | 0 | 0.002616 |
| FH_psychosis | 0 | 0.046572 |
| N_dep_b4 | 0 | 1.995139 |
| N_man_b4 | 0 | 0.893913 |
| anxiety | 0 | 0.547383 |
| stress | 0 | 0.302052 |
| hi_LDL | 0 | 0.228529 |
| lo_HDL | 1 | 0.119001 |
| weight | 0 | 1.014700 |
| CKD3 | 0 | 0.094266 |
| symptom_to_exposure | 143 | 8.074818 |
| dominant | 0 | 0.988668 |
| age_first_exposure | 0 | 9.559759 |
| age_first_reg | 0 | 9.567249 |
| age_first_diagnosis | 26 | 8.743020 |
| adhd | FH_suicide | mania | psychosis | relationship | self_harm | sex | sleep | smoker | T2DM | OCD | migraine | hypothyroid | CHD | other_substance_misuse | cannabis | alcohol | depression | FH_anxiety | FH_any | FH_BPD | FH_depression | FH_LD | FH_psychosis | N_dep_b4 | N_man_b4 | anxiety | stress | hi_LDL | lo_HDL | weight | CKD3 | symptom_to_exposure | dominant | age_first_exposure | age_first_reg | age_first_diagnosis | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 0 | 0 | 0 | 1 | 0 | 0 | female | 0 | ex smoker | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0.0 | healthy weight | 0 | 0.000000 | unclear | 69.434631 | 64.391781 | 80.722794 |
| 2 | 0 | 0 | 0 | 0 | 0 | 1 | female | 1 | ex smoker | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 1 | 0.0 | healthy weight | 0 | 24.128679 | depression | 43.627651 | 27.057534 | 41.752224 |
| 3 | 0 | 0 | 1 | 0 | 0 | 0 | female | 0 | current smoker | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0.0 | overweight | 0 | 0.000000 | mania | 56.533882 | 50.843836 | 56.533882 |
| 4 | 0 | 0 | 1 | 0 | 0 | 1 | female | 1 | ex smoker | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 12 | 1 | 0 | 0 | 0 | 0.0 | healthy weight | 0 | 10.973306 | depression | 26.524298 | 22.583562 | 21.207392 |
| 5 | 0 | 0 | 0 | 1 | 0 | 0 | female | 0 | current smoker | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 0.0 | healthy weight | 0 | 13.349760 | depression | 72.032852 | 67.394521 | 82.948662 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 38950 | 0 | 0 | 0 | 0 | 1 | 1 | female | 0 | current smoker | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0.0 | healthy weight | 1 | 18.521561 | unclear | 74.286102 | 45.613699 | 55.764545 |
| 38951 | 0 | 0 | 0 | 1 | 0 | 0 | female | 0 | never smoker | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 0 | 0.0 | healthy weight | 0 | 7.293634 | depression | 43.791924 | 44.260274 | 46.135525 |
| 38953 | 0 | 0 | 0 | 0 | 1 | 0 | female | 1 | current smoker | 0 | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 13 | 0 | 1 | 1 | 0 | 0.0 | overweight | 0 | 11.707050 | depression | 37.223820 | 45.471233 | 28.867899 |
| 38954 | 0 | 0 | 1 | 0 | 0 | 0 | female | 0 | ex smoker | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 0 | 0.0 | healthy weight | 0 | 16.922655 | mania | 48.421631 | 39.983562 | 31.498974 |
| 38956 | 0 | 0 | 0 | 1 | 0 | 0 | female | 0 | current smoker | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 0.0 | healthy weight | 0 | 14.631075 | depression | 39.129364 | 30.767123 | 38.529774 |
31518 rows × 37 columns
key = '34_exp'
import itertools
# Scatter plot with a fitted regression line for every pair of
# numerical features.
numerical_columns = X_dict['num_' + key].columns
for x_col, y_col in itertools.combinations(numerical_columns, 2):
    plt.figure(figsize=(16, 9))
    sns.regplot(x=x_col, y=y_col, data=X_dict['num_' + key],
                line_kws={"color": "red"})
display(X_dict['num_' + key])
| age_first_exposure | age_first_diagnosis | symptom_to_exposure | |
|---|---|---|---|
| 1 | 69.434631 | 80.722794 | 0.000000 |
| 2 | 43.627651 | 41.752224 | 24.128679 |
| 3 | 56.533882 | 56.533882 | 0.000000 |
| 4 | 26.524298 | 21.207392 | 10.973306 |
| 5 | 72.032852 | 82.948662 | 13.349760 |
| ... | ... | ... | ... |
| 38950 | 74.286102 | 55.764545 | 18.521561 |
| 38951 | 43.791924 | 46.135525 | 7.293634 |
| 38953 | 37.223820 | 28.867899 | 11.707050 |
| 38954 | 48.421631 | 31.498974 | 16.922655 |
| 38956 | 39.129364 | 38.529774 | 14.631075 |
31369 rows × 3 columns
from pyitlib import discrete_random_variable as drv
from scipy.spatial import distance
def similarity(X, Y):
    """Simple Matching Coefficient between two binary vectors.

    Counts the positions where X and Y agree (both 1 or both 0) and
    divides by the vector length.
    """
    agreement = (X * Y) + ((1 - X) * (1 - Y))
    return agreement.sum() / len(X)
def coeff(X_dict, y, key, feature):
    """Per-feature association measures with the target.

    Pearson/Spearman correlations for numerical features, conditional
    entropy for categorical ones, and Jaccard/SMC for the binary subset
    (only when the target itself is not float). Displays the table sorted
    by absolute Pearson r (NaNs last) and returns it.
    """
    feature_names = features[feature]
    print(len(feature_names), 'features')
    df_coeff = pd.DataFrame(index=feature_names)
    num_cols = X_dict['num_' + key].columns
    cat_cols = X_dict['cat_' + key].columns
    bin_cols = X_dict['bin_' + key].columns
    for feat in feature_names:
        # Correlation coefficients only make sense for numerical features.
        if feat in num_cols:
            df_coeff.loc[feat, 'pearson_r'], df_coeff.loc[feat, 'pearson_p'] = pearsonr(X[key][feat], y[key])
            df_coeff.loc[feat, 'spearman_r'], df_coeff.loc[feat, 'spearman_p'] = spearmanr(X[key][feat], y[key])
        # Conditional entropy applies to all categorical features,
        # including the non-binary ones.
        if feat in cat_cols:
            df_coeff.loc[feat, 'cond_entropy'] = drv.entropy_conditional(
                [int(x) for x in X[key][feat]], [int(x) for x in y[key]], base=np.e)
        # For binary feature / non-float target pairs, Jaccard similarity
        # and the Simple Matching Coefficient are more appropriate.
        if (feat in bin_cols) and (y[key].dtype != np.float64):
            df_coeff.loc[feat, 'jaccard'] = 1 - distance.jaccard(X[key][feat], y[key])
            df_coeff.loc[feat, 'SMC'] = similarity(X[key][feat], y[key])
    # Show rows ordered by |pearson_r| descending.
    display(df_coeff
            .reindex(df_coeff.pearson_r.abs().sort_values(ascending=False).index))
    return df_coeff
For binary features:
feature = '34'
# Association tables for the informed (34-feature) set, one per target.
for target in targets:
    key = feature + '_' + target
    print('_' * 10 + target + '_' * 10)
    df_coeff = coeff(X_dict, y, key, feature)
    print()
__________resp__________ 34 features
| pearson_r | pearson_p | spearman_r | spearman_p | cond_entropy | jaccard | SMC | |
|---|---|---|---|---|---|---|---|
| age_first_exposure | 0.193293 | 1.693068e-221 | 0.195700 | 4.192704e-227 | NaN | NaN | NaN |
| age_first_diagnosis | 0.131446 | 1.780416e-102 | 0.130364 | 8.174834e-101 | NaN | NaN | NaN |
| symptom_to_exposure | 0.122443 | 4.498120e-89 | 0.119015 | 3.168104e-84 | NaN | NaN | NaN |
| psychosis | NaN | NaN | NaN | NaN | 0.589143 | 0.209603 | 0.529292 |
| depression | NaN | NaN | NaN | NaN | 0.676162 | 0.345302 | 0.495832 |
| mania | NaN | NaN | NaN | NaN | 0.574144 | 0.197438 | 0.527255 |
| dominant | NaN | NaN | NaN | NaN | 0.989417 | NaN | NaN |
| sex | NaN | NaN | NaN | NaN | 0.679791 | 0.274790 | 0.509186 |
| FH_BPD | NaN | NaN | NaN | NaN | 0.086441 | 0.015442 | 0.552642 |
| FH_depression | NaN | NaN | NaN | NaN | 0.095472 | 0.017635 | 0.552416 |
| FH_psychosis | NaN | NaN | NaN | NaN | 0.044860 | 0.006384 | 0.553774 |
| weight | NaN | NaN | NaN | NaN | 1.014475 | NaN | NaN |
| self_harm | NaN | NaN | NaN | NaN | 0.391276 | 0.108984 | 0.536459 |
| cannabis | NaN | NaN | NaN | NaN | 0.076336 | 0.011893 | 0.551813 |
| anxiety | NaN | NaN | NaN | NaN | 0.548561 | 0.172467 | 0.518352 |
| stress | NaN | NaN | NaN | NaN | 0.300254 | 0.064856 | 0.531140 |
| sleep | NaN | NaN | NaN | NaN | 0.379208 | 0.104652 | 0.537516 |
| other_substance_misuse | NaN | NaN | NaN | NaN | 0.192186 | 0.044094 | 0.549398 |
| relationship | NaN | NaN | NaN | NaN | 0.392561 | 0.113042 | 0.539741 |
| OCD | NaN | NaN | NaN | NaN | 0.098012 | 0.019634 | 0.553586 |
| adhd | NaN | NaN | NaN | NaN | 0.033642 | 0.004212 | 0.554038 |
| smoker | NaN | NaN | NaN | NaN | 1.065933 | NaN | NaN |
| alcohol | NaN | NaN | NaN | NaN | 0.227904 | 0.049326 | 0.542684 |
| FH_suicide | NaN | NaN | NaN | NaN | 0.026897 | 0.004309 | 0.555396 |
| hi_LDL | NaN | NaN | NaN | NaN | 0.229586 | 0.055144 | 0.547550 |
| lo_HDL | NaN | NaN | NaN | NaN | 0.119419 | 0.025348 | 0.553246 |
| CKD3 | NaN | NaN | NaN | NaN | 0.092342 | 0.019442 | 0.554793 |
| T2DM | NaN | NaN | NaN | NaN | 0.175593 | 0.049850 | 0.559244 |
| migraine | NaN | NaN | NaN | NaN | 0.203305 | 0.044653 | 0.546418 |
| hypothyroid | NaN | NaN | NaN | NaN | 0.201963 | 0.057541 | 0.558226 |
| CHD | NaN | NaN | NaN | NaN | 0.144386 | 0.036291 | 0.556226 |
| FH_anxiety | NaN | NaN | NaN | NaN | 0.009570 | 0.001187 | 0.555509 |
| FH_any | NaN | NaN | NaN | NaN | 0.231687 | 0.052400 | 0.544306 |
| FH_LD | NaN | NaN | NaN | NaN | 0.002625 | 0.000594 | 0.555924 |
__________exp__________ 34 features
| pearson_r | pearson_p | spearman_r | spearman_p | cond_entropy | jaccard | SMC | |
|---|---|---|---|---|---|---|---|
| age_first_exposure | -0.106543 | 7.323420e-80 | -0.108190 | 2.643016e-82 | NaN | NaN | NaN |
| symptom_to_exposure | -0.075266 | 1.201673e-40 | -0.082721 | 9.258229e-49 | NaN | NaN | NaN |
| age_first_diagnosis | -0.059584 | 4.470683e-26 | -0.054850 | 2.440734e-22 | NaN | NaN | NaN |
| psychosis | NaN | NaN | NaN | NaN | 0.589280 | 0.219806 | 0.570021 |
| depression | NaN | NaN | NaN | NaN | 0.676135 | 0.315729 | 0.487009 |
| mania | NaN | NaN | NaN | NaN | 0.572370 | 0.204563 | 0.567630 |
| dominant | NaN | NaN | NaN | NaN | 0.988610 | NaN | NaN |
| sex | NaN | NaN | NaN | NaN | 0.678129 | 0.278029 | 0.541394 |
| FH_BPD | NaN | NaN | NaN | NaN | 0.085809 | 0.010562 | 0.596831 |
| FH_depression | NaN | NaN | NaN | NaN | 0.096748 | 0.019133 | 0.601231 |
| FH_psychosis | NaN | NaN | NaN | NaN | 0.046308 | 0.011288 | 0.606299 |
| weight | NaN | NaN | NaN | NaN | 1.010127 | NaN | NaN |
| self_harm | NaN | NaN | NaN | NaN | 0.391315 | 0.123489 | 0.588415 |
| cannabis | NaN | NaN | NaN | NaN | 0.075429 | 0.027083 | 0.611782 |
| anxiety | NaN | NaN | NaN | NaN | 0.541934 | 0.224798 | 0.600179 |
| stress | NaN | NaN | NaN | NaN | 0.293544 | 0.126974 | 0.624534 |
| sleep | NaN | NaN | NaN | NaN | 0.373184 | 0.150662 | 0.615416 |
| other_substance_misuse | NaN | NaN | NaN | NaN | 0.185501 | 0.074975 | 0.619274 |
| relationship | NaN | NaN | NaN | NaN | 0.388881 | 0.092250 | 0.562402 |
| OCD | NaN | NaN | NaN | NaN | 0.095491 | 0.025091 | 0.606108 |
| adhd | NaN | NaN | NaN | NaN | 0.033051 | 0.007548 | 0.605980 |
| smoker | NaN | NaN | NaN | NaN | 1.060558 | NaN | NaN |
| alcohol | NaN | NaN | NaN | NaN | 0.222606 | 0.075436 | 0.609678 |
| FH_suicide | NaN | NaN | NaN | NaN | 0.027568 | 0.004173 | 0.604386 |
| hi_LDL | NaN | NaN | NaN | NaN | 0.223190 | 0.087069 | 0.617616 |
| lo_HDL | NaN | NaN | NaN | NaN | 0.117458 | 0.036801 | 0.609519 |
| CKD3 | NaN | NaN | NaN | NaN | 0.093997 | 0.023179 | 0.605024 |
| T2DM | NaN | NaN | NaN | NaN | 0.175785 | 0.039021 | 0.595684 |
| migraine | NaN | NaN | NaN | NaN | 0.203515 | 0.056255 | 0.601039 |
| hypothyroid | NaN | NaN | NaN | NaN | 0.200442 | 0.037254 | 0.586439 |
| CHD | NaN | NaN | NaN | NaN | 0.145779 | 0.035962 | 0.601772 |
| FH_anxiety | NaN | NaN | NaN | NaN | 0.010400 | 0.001532 | 0.605183 |
| FH_any | NaN | NaN | NaN | NaN | 0.233395 | 0.057539 | 0.592719 |
| FH_LD | NaN | NaN | NaN | NaN | 0.002539 | 0.000565 | 0.605502 |
__________multi__________ 34 features
| pearson_r | pearson_p | spearman_r | spearman_p | cond_entropy | jaccard | SMC | |
|---|---|---|---|---|---|---|---|
| age_first_exposure | 0.176324 | 1.914504e-217 | 0.177382 | 4.462611e-220 | NaN | NaN | NaN |
| age_first_diagnosis | 0.120163 | 3.210233e-101 | 0.118399 | 2.643382e-98 | NaN | NaN | NaN |
| symptom_to_exposure | 0.112088 | 3.062966e-88 | 0.107891 | 7.372539e-82 | NaN | NaN | NaN |
| psychosis | NaN | NaN | NaN | NaN | 0.590599 | 0.067322 | 0.327106 |
| depression | NaN | NaN | NaN | NaN | 0.676158 | 0.113832 | 0.584686 |
| mania | NaN | NaN | NaN | NaN | 0.573145 | 0.060513 | 0.306162 |
| dominant | NaN | NaN | NaN | NaN | 0.989188 | NaN | NaN |
| sex | NaN | NaN | NaN | NaN | 0.679341 | 0.087408 | 0.432688 |
| FH_BPD | NaN | NaN | NaN | NaN | 0.086569 | 0.004969 | 0.106092 |
| FH_depression | NaN | NaN | NaN | NaN | 0.096706 | 0.006140 | 0.108260 |
| FH_psychosis | NaN | NaN | NaN | NaN | 0.046645 | 0.002923 | 0.098919 |
| weight | NaN | NaN | NaN | NaN | 1.014602 | NaN | NaN |
| self_harm | NaN | NaN | NaN | NaN | 0.391790 | 0.035012 | 0.194523 |
| cannabis | NaN | NaN | NaN | NaN | 0.078415 | 0.005153 | 0.102745 |
| anxiety | NaN | NaN | NaN | NaN | 0.547361 | 0.055399 | 0.268067 |
| stress | NaN | NaN | NaN | NaN | 0.301458 | 0.024737 | 0.142657 |
| sleep | NaN | NaN | NaN | NaN | 0.379355 | 0.033202 | 0.189773 |
| other_substance_misuse | NaN | NaN | NaN | NaN | 0.191935 | 0.013252 | 0.131276 |
| relationship | NaN | NaN | NaN | NaN | 0.389897 | 0.032445 | 0.199050 |
| OCD | NaN | NaN | NaN | NaN | 0.096090 | 0.004843 | 0.110077 |
| adhd | NaN | NaN | NaN | NaN | 0.033256 | 0.001434 | 0.096688 |
| smoker | NaN | NaN | NaN | NaN | 1.065385 | NaN | NaN |
| alcohol | NaN | NaN | NaN | NaN | 0.224836 | 0.014706 | 0.131499 |
| FH_suicide | NaN | NaN | NaN | NaN | 0.027556 | 0.001437 | 0.097899 |
| hi_LDL | NaN | NaN | NaN | NaN | 0.228775 | 0.016350 | 0.141031 |
| lo_HDL | NaN | NaN | NaN | NaN | 0.119321 | 0.007289 | 0.115656 |
| CKD3 | NaN | NaN | NaN | NaN | 0.094304 | 0.006274 | 0.111671 |
| T2DM | NaN | NaN | NaN | NaN | 0.175073 | 0.011715 | 0.142625 |
| migraine | NaN | NaN | NaN | NaN | 0.203864 | 0.014731 | 0.130320 |
| hypothyroid | NaN | NaN | NaN | NaN | 0.200578 | 0.013537 | 0.149319 |
| CHD | NaN | NaN | NaN | NaN | 0.145722 | 0.010078 | 0.128311 |
| FH_anxiety | NaN | NaN | NaN | NaN | 0.010378 | 0.000600 | 0.095190 |
| FH_any | NaN | NaN | NaN | NaN | 0.233232 | 0.018060 | 0.137237 |
| FH_LD | NaN | NaN | NaN | NaN | 0.002521 | 0.000060 | 0.094807 |
__________lithium2y__________ 34 features
| pearson_r | pearson_p | spearman_r | spearman_p | cond_entropy | jaccard | SMC | |
|---|---|---|---|---|---|---|---|
| age_first_exposure | 0.154052 | 7.457636e-166 | 0.157628 | 1.206823e-173 | NaN | NaN | NaN |
| age_first_diagnosis | 0.104651 | 4.215269e-77 | 0.101220 | 3.181159e-72 | NaN | NaN | NaN |
| symptom_to_exposure | 0.103478 | 2.051218e-75 | 0.101697 | 6.819342e-73 | NaN | NaN | NaN |
| psychosis | NaN | NaN | NaN | NaN | 0.590447 | 0.142651 | 0.601294 |
| depression | NaN | NaN | NaN | NaN | 0.676061 | 0.222709 | 0.462495 |
| mania | NaN | NaN | NaN | NaN | 0.572964 | 0.137856 | 0.610635 |
| dominant | NaN | NaN | NaN | NaN | 0.988927 | NaN | NaN |
| sex | NaN | NaN | NaN | NaN | 0.679087 | 0.176928 | 0.530779 |
| FH_BPD | NaN | NaN | NaN | NaN | 0.086570 | 0.018090 | 0.738723 |
| FH_depression | NaN | NaN | NaN | NaN | 0.096748 | 0.018281 | 0.736364 |
| FH_psychosis | NaN | NaN | NaN | NaN | 0.046481 | 0.004525 | 0.740540 |
| weight | NaN | NaN | NaN | NaN | 1.014417 | NaN | NaN |
| self_harm | NaN | NaN | NaN | NaN | 0.391424 | 0.082263 | 0.672097 |
| cannabis | NaN | NaN | NaN | NaN | 0.077634 | 0.006322 | 0.734419 |
| anxiety | NaN | NaN | NaN | NaN | 0.546391 | 0.115886 | 0.611113 |
| stress | NaN | NaN | NaN | NaN | 0.299668 | 0.041727 | 0.683732 |
| sleep | NaN | NaN | NaN | NaN | 0.378212 | 0.070755 | 0.670088 |
| other_substance_misuse | NaN | NaN | NaN | NaN | 0.190062 | 0.022924 | 0.711945 |
| relationship | NaN | NaN | NaN | NaN | 0.389486 | 0.109298 | 0.690331 |
| OCD | NaN | NaN | NaN | NaN | 0.095951 | 0.014316 | 0.734419 |
| adhd | NaN | NaN | NaN | NaN | 0.033194 | 0.003209 | 0.742516 |
| smoker | NaN | NaN | NaN | NaN | 1.062893 | NaN | NaN |
| alcohol | NaN | NaN | NaN | NaN | 0.223868 | 0.032913 | 0.706812 |
| FH_suicide | NaN | NaN | NaN | NaN | 0.027538 | 0.005091 | 0.744557 |
| hi_LDL | NaN | NaN | NaN | NaN | 0.227714 | 0.035373 | 0.707036 |
| lo_HDL | NaN | NaN | NaN | NaN | 0.119160 | 0.019302 | 0.731136 |
| CKD3 | NaN | NaN | NaN | NaN | 0.094344 | 0.016388 | 0.735950 |
| T2DM | NaN | NaN | NaN | NaN | 0.175457 | 0.046268 | 0.729924 |
| migraine | NaN | NaN | NaN | NaN | 0.203862 | 0.039913 | 0.717811 |
| hypothyroid | NaN | NaN | NaN | NaN | 0.200551 | 0.056950 | 0.728139 |
| CHD | NaN | NaN | NaN | NaN | 0.145969 | 0.029967 | 0.729638 |
| FH_anxiety | NaN | NaN | NaN | NaN | 0.010389 | 0.001001 | 0.745354 |
| FH_any | NaN | NaN | NaN | NaN | 0.233339 | 0.048752 | 0.713252 |
| FH_LD | NaN | NaN | NaN | NaN | 0.002609 | 0.000125 | 0.745991 |
from pandas import Series, DataFrame
import pandas as pd
import numpy as np
from collections import Counter
import os
from datetime import date
from sklearn.feature_selection import chi2
from scipy import stats
import seaborn as sns
import matplotlib.pylab as plt
from numpy import percentile
from sklearn.feature_selection import SelectKBest
def chisquare(Y):
    """Pairwise chi-square p-value matrix for the int64 columns of Y.

    Each cell (i, j) holds the p-value of a chi-square test of
    independence between columns i and j, rounded to 3 decimals, with
    two sentinel values:
      * 2    -- test unreliable: fewer than 20% of the contingency-table
                cells have an expected count of at least 5;
      * 0.00 -- the diagonal (a column tested against itself).

    Returns
    -------
    pandas.DataFrame indexed and columned by the int64 column names.
    """
    X = Y.loc[:, Y.dtypes == np.int64]
    column_names = X.columns
    # Seed an empty square matrix. (The previous version constructed it by
    # re-indexing the data frame itself, which only worked because every
    # cell was overwritten below.)
    chisqmatrix = pd.DataFrame(index=column_names, columns=column_names,
                               dtype=float)
    for i, icol in enumerate(column_names):
        for j, jcol in enumerate(column_names):
            mycrosstab = pd.crosstab(X[icol], X[jcol])
            stat, p, dof, expected = stats.chi2_contingency(mycrosstab)
            chisqmatrix.iloc[i, j] = round(p, 3)
            # Share of contingency cells with an adequate expected count.
            cntexpected = expected[expected < 5].size
            perexpected = ((expected.size - cntexpected) / expected.size) * 100
            if perexpected < 20:
                chisqmatrix.iloc[i, j] = 2
            if icol == jcol:
                chisqmatrix.iloc[i, j] = 0.00
    return chisqmatrix
# Pairwise correlation matrices for the numerical features (two rank
# conventions) and the chi-square p-value matrix for the categorical ones.
# NOTE(review): `key` here still holds whatever the previous loop cell
# left in it — confirm this is the intended feature set.
df_corr_pearson = X_dict['num_'+ key].corr(method='pearson', min_periods=1)
df_corr_spearman = X_dict['num_' + key].corr(method='spearman', min_periods=1)
df_corr_chisquare = chisquare(X_dict['cat_' + key])
print(X_dict[key].dtypes)
age_first_exposure float64 age_first_diagnosis float64 symptom_to_exposure float64 psychosis int64 depression int64 mania int64 dominant int64 sex int64 FH_BPD int64 FH_depression int64 FH_psychosis int64 weight int64 self_harm int64 cannabis int64 anxiety int64 stress int64 sleep int64 other_substance_misuse int64 relationship int64 OCD int64 adhd int64 smoker int64 alcohol int64 FH_suicide int64 hi_LDL int64 lo_HDL int64 CKD3 int64 T2DM int64 migraine int64 hypothyroid int64 CHD int64 FH_anxiety int64 FH_any int64 FH_LD int64 dtype: object
import matplotlib.pyplot as plt
%matplotlib inline
sns.set_theme(style="white")
sns.set(font_scale=2)
def plot_matrix_corr(df_corr, fontsize=16, xsize=9, ysize=9):
    """Render a square correlation/p-value matrix as an annotated heatmap.

    Parameters
    ----------
    df_corr : pandas.DataFrame
        Square matrix whose columns label both axes.
    fontsize : int
        Font size of the cell annotations.
    xsize, ysize : int
        Figure dimensions in inches.
    """
    cmap = sns.diverging_palette(230, 20, as_cmap=True)
    plt.figure(figsize=(xsize, ysize))
    annotation_style = {
        'fontsize': fontsize,
        'fontweight': 'bold',
        'fontfamily': 'serif',
    }
    sns.heatmap(df_corr,
                xticklabels=df_corr.columns,
                yticklabels=df_corr.columns,
                annot=True,
                linewidths=0.5,
                cmap=cmap,
                fmt='.3f',
                annot_kws=annotation_style)
# Larger fonts for the two correlation heatmaps, smaller for the denser
# chi-square p-value matrix.
plot_matrix_corr(df_corr_pearson, 32, 16, 12)
plot_matrix_corr(df_corr_spearman, 32, 16, 12)
sns.set_theme(style="white")
plot_matrix_corr(df_corr_chisquare, 8, 22, 14)
About Conditional Entropy.
Here, $X$ is the response and $Y$ is the feature being considered. Two properties:
import scipy.stats as stats
def anova(X_dict, key, target, data=None):
    """One-way ANOVA of every numerical feature against a target column.

    Parameters
    ----------
    X_dict : dict of pandas.DataFrame
        Must contain the numerical feature table under 'num_' + key.
    key : str
        Feature-set/target key, e.g. '34_resp'.
    target : str
        Column name of the grouping variable.
    data : pandas.DataFrame, optional
        Table holding the target column; defaults to the notebook-level
        ``df``.

    Returns
    -------
    pandas.DataFrame with columns ['feature', 'F-Statistic', 'p-value'].
    """
    source = df if data is None else data
    anova_df = X_dict['num_' + key].merge(source[target], left_index=True,
                                          right_index=True)
    # Bug fix: the test previously compared only the groups target==0.0 and
    # target==1.0, silently ignoring any further classes (e.g. class 2 of
    # the 'multi' target). Build one group per observed target value instead;
    # results are unchanged for binary targets.
    group_values = sorted(anova_df[target].dropna().unique())
    results_list = list()
    for feature in X_dict['num_' + key].keys():
        groups = [anova_df.loc[anova_df[target] == value, feature].values
                  for value in group_values]
        res = stats.f_oneway(*groups)
        results_list.append([feature, res.statistic, res.pvalue])
    return pd.DataFrame(results_list, columns=['feature', 'F-Statistic', 'p-value'])
feature = '34'
# One ANOVA table per target for the informed ('34') feature set.
for target in targets:
    key = '_'.join([feature, target])
    print('key:', key)
    display(anova(X_dict, key, targets[target]))
key: 34_resp
| feature | F-Statistic | p-value | |
|---|---|---|---|
| 0 | age_first_exposure | 1028.801036 | 1.693068e-221 |
| 1 | age_first_diagnosis | 466.040468 | 1.780416e-102 |
| 2 | symptom_to_exposure | 403.450918 | 4.498120e-89 |
key: 34_exp
| feature | F-Statistic | p-value | |
|---|---|---|---|
| 0 | age_first_exposure | 360.151148 | 7.323420e-80 |
| 1 | age_first_diagnosis | 111.756084 | 4.470683e-26 |
| 2 | symptom_to_exposure | 178.703821 | 1.201673e-40 |
key: 34_multi
| feature | F-Statistic | p-value | |
|---|---|---|---|
| 0 | age_first_exposure | 46.909192 | 7.655806e-12 |
| 1 | age_first_diagnosis | 27.350178 | 1.714934e-07 |
| 2 | symptom_to_exposure | 13.520850 | 2.365581e-04 |
key: 34_lithium2y
| feature | F-Statistic | p-value | |
|---|---|---|---|
| 0 | age_first_exposure | 762.493116 | 7.457636e-166 |
| 1 | age_first_diagnosis | 347.333055 | 4.215269e-77 |
| 2 | symptom_to_exposure | 339.500883 | 2.051218e-75 |
import scipy.stats as stats
# Two-sample t-tests (first `sample` rows of each target group) for every
# numerical feature of the informed set.
sample = 1000
for target in targets:
    key = feature + '_' + target
    print('key:', key)
    ttest_df = X_dict['num_' + key].merge(df[targets[target]],
                                          left_index=True, right_index=True)
    results_list = list()
    for feat in X_dict['num_' + key].keys():
        group0 = ttest_df.loc[ttest_df[targets[target]] == 0.0, feat].head(sample)
        group1 = ttest_df.loc[ttest_df[targets[target]] == 1.0, feat].head(sample)
        res = stats.ttest_ind(group0, group1, equal_var=True)
        results_list.append([feat, res.statistic, res.pvalue])
    ttest_results = pd.DataFrame(results_list,
                                 columns=['feature', 'T-Statistic', 'p-value'])
    display(ttest_results)
key: 34_resp
| feature | T-Statistic | p-value | |
|---|---|---|---|
| 0 | age_first_exposure | -8.560086 | 2.212296e-17 |
| 1 | age_first_diagnosis | -4.615607 | 4.168509e-06 |
| 2 | symptom_to_exposure | -6.759800 | 1.807514e-11 |
key: 34_exp
| feature | T-Statistic | p-value | |
|---|---|---|---|
| 0 | age_first_exposure | 5.359331 | 9.317705e-08 |
| 1 | age_first_diagnosis | 3.378131 | 7.438096e-04 |
| 2 | symptom_to_exposure | 3.309470 | 9.513413e-04 |
key: 34_multi
| feature | T-Statistic | p-value | |
|---|---|---|---|
| 0 | age_first_exposure | -3.410012 | 0.000663 |
| 1 | age_first_diagnosis | -2.348080 | 0.018967 |
| 2 | symptom_to_exposure | -2.479474 | 0.013240 |
key: 34_lithium2y
| feature | T-Statistic | p-value | |
|---|---|---|---|
| 0 | age_first_exposure | -7.933297 | 3.521460e-15 |
| 1 | age_first_diagnosis | -3.810229 | 1.430358e-04 |
| 2 | symptom_to_exposure | -6.630426 | 4.294092e-11 |
We split by time (exposure start). The 80% first patients as training set, the 20% last patients as test set.
# Time-based split on exposure start: the earliest 80% of patients form the
# training set, the latest 20% the test set.
x_train = dict()
y_train = dict()
x_test = dict()
y_test = dict()
date_format = '%d%b%Y'
df['exposure_start'] = pd.to_datetime(df['exposure_start'], format=date_format)
for feature_set in X_dict:
    print(feature_set)
    X_sorted = (X_dict[feature_set]
                .merge(df['exposure_start'], left_index=True, right_index=True)
                .sort_values(by='exposure_start'))
    y_sorted = y_dict[feature_set][X_sorted.index]
    cutoff = int(len(X_sorted) * .8)  # 80% of the rows go to training
    train_index = X_sorted.index[:cutoff]
    test_index = X_sorted.index[cutoff:]
    x_train[feature_set] = X_dict[feature_set][X_dict[feature_set].index.isin(train_index)]
    y_train[feature_set] = y_dict[feature_set][y_dict[feature_set].index.isin(train_index)]
    x_test[feature_set] = X_dict[feature_set][X_dict[feature_set].index.isin(test_index)]
    y_test[feature_set] = y_dict[feature_set][y_dict[feature_set].index.isin(test_index)]
from sklearn.metrics import roc_auc_score
def results(clf, X1, y, v):
    """Evaluate a fitted classifier and plot its confusion matrix.

    Parameters
    ----------
    clf : fitted classifier exposing predict / predict_proba.
    X1 : pandas.DataFrame or pandas.Series
        Test features; a Series is promoted to a one-column frame.
    y : array-like of true labels.
    v : str
        Experiment key; its suffix ('_multi', '_exp', '_lithium2y',
        default = response) selects the class names and the ROC-AUC
        multi-class handling.

    Returns
    -------
    dict : the classification report (output_dict form) with an extra
    'roc_auc' entry (None when the score could not be computed).
    """
    X = X1
    if isinstance(X1, pd.Series):
        X = X1.to_frame()
    y_pred = clf.predict(X)
    # Default: binary task, keep the probability of the positive class.
    y_score = clf.predict_proba(X)[:, 1]
    target_names = ['No Response', 'Response']
    multiclass = 'raise'
    if '_multi' in v:
        # Multi-class task: keep the full probability matrix and score
        # one-vs-rest.
        y_score = clf.predict_proba(X)
        multiclass = 'ovr'
        target_names = ['No Response', 'Equivocal', 'Response']
    if '_exp' in v:
        target_names = ['Lithium', 'Olanzapine']
    if '_lithium2y' in v:
        target_names = ['other', 'lithium2y']
    result = classification_report(y, y_pred, target_names=target_names, output_dict=True)
    # Compute confusion matrix
    cm = confusion_matrix(y, y_pred)
    print('Balanced Accuracy:', result['macro avg']['recall'])
    # For average and multi_class parameters, check doc:
    # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score
    try:
        roc_auc_score_val = roc_auc_score(y, y_score, average='weighted', multi_class=multiclass)
    except ValueError:
        # Bug fix: this was a bare `except:` which also hid unrelated
        # errors; roc_auc_score signals bad inputs with ValueError.
        # 34_multi fails with Naive Bayes because MixedNB predict_proba()
        # seems to return proba values that don't sum up to 1...
        roc_auc_score_val = None
    result['roc_auc'] = roc_auc_score_val
    print('ROC_AUC score:', roc_auc_score_val)
    # Show confusion matrix. Pass the colormap by name: plt.cm.get_cmap is
    # deprecated and removed in matplotlib >= 3.9.
    plt.matshow(cm, cmap='Blues')
    plt.title('Confusion matrix\n')
    plt.colorbar()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()
    return result
import warnings
def evaluate(clf, x_test, y_test):
    """Score every fitted classifier on its test set and tabulate metrics.

    Parameters
    ----------
    clf : dict mapping feature-set name -> fitted classifier
    x_test : dict mapping feature-set name -> test features
    y_test : dict mapping feature-set name -> test labels

    Returns
    -------
    pd.DataFrame
        One row per classifier; per-class f1 columns that do not apply to a
        given task are left as NaN.
    """
    columns = ['features',
               'balanced accuracy',
               'accuracy',
               'roc_auc',
               'f1 (response)',
               'f1 (equivocal)',
               'f1 (no response)',
               'f1 (lithium)',
               'f1 (olanzapine)',
               'f1 (lithium > 2y)',
               'f1 (other)',
               'f1 (weighted avg)']
    # Collect row dicts and build the frame once at the end:
    # DataFrame.append() was removed in pandas 2.0 and row-by-row growth
    # is quadratic anyway.
    rows = []
    for v in clf:
        print(40*'_' + v + 40*'_')
        # Strip any scoring suffix to recover the feature-set key used by
        # the x_test / y_test dicts.
        v2 = v.replace('_balanced_accuracy', '').replace('_accuracy', '').replace('_f1_weighted', '')
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            result = results(clf[v], x_test[v2], y_test[v2], v)
        result_dict = {'features': v,
                       'accuracy': result['accuracy'],
                       'roc_auc': result['roc_auc'],
                       'balanced accuracy': result['macro avg']['recall'],
                       'f1 (weighted avg)': result['weighted avg']['f1-score']
                       }
        # Per-class f1 keys depend on the task encoded in the name suffix.
        if '_exp' in v:
            result_dict.update({
                'f1 (lithium)': result['Lithium']['f1-score'],
                'f1 (olanzapine)': result['Olanzapine']['f1-score'],
            })
        elif '_multi' in v:
            result_dict.update({
                'f1 (no response)': result['No Response']['f1-score'],
                'f1 (equivocal)': result['Equivocal']['f1-score'],
                'f1 (response)': result['Response']['f1-score']
            })
        elif '_lithium2y' in v:
            result_dict.update({
                'f1 (other)': result['other']['f1-score'],
                'f1 (lithium > 2y)': result['lithium2y']['f1-score']
            })
        else:
            result_dict.update({
                'f1 (no response)': result['No Response']['f1-score'],
                'f1 (response)': result['Response']['f1-score']
            })
        rows.append(result_dict)
        print(83*'_')
        print('\n\n\n')
    # Explicit columns keep the original order and fill absent keys with NaN,
    # matching the behaviour of the old append(ignore_index=True) loop.
    return pd.DataFrame(rows, columns=columns)
from sklearn.utils._testing import ignore_warnings
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import LogisticRegressionCV
@ignore_warnings(category=ConvergenceWarning) # saga at the default max_iter (=100?) triggers this warning
def run(X1, y, l1, l2, cv=5, n_jobs=4):
    """Fit an elastic-net LogisticRegressionCV on the given features.

    Parameters
    ----------
    X1 : DataFrame or Series of features (a Series is promoted to a
         one-column DataFrame so fit() accepts it)
    y : training labels
    l1, l2 : the two l1_ratio candidates tried by cross-validation
    cv : number of CV folds (default 5, as before — now tunable)
    n_jobs : parallel workers for the CV search (default 4, as before)

    Returns
    -------
    Fitted LogisticRegressionCV selecting the best l1_ratio by balanced
    accuracy.
    """
    X = X1.to_frame() if isinstance(X1, pd.Series) else X1
    return LogisticRegressionCV(penalty='elasticnet',
                                l1_ratios=[l1, l2],
                                cv=cv,
                                solver='saga',
                                scoring='balanced_accuracy',
                                n_jobs=n_jobs).fit(X, y)
clf_multi_dict = dict()
# Elastic net regularisation for larger feature sets
l1 = .8
l2 = .2
# Partition the feature sets by width before fitting, so the two passes
# below mirror the original order: wide sets first, then narrow ones.
wide_sets = [f for f in X_dict if len(X_dict[f].columns) > 5]
narrow_sets = [f for f in X_dict if len(X_dict[f].columns) <= 5]
# Wide feature sets (more than 5 columns): l1-dominant ratio listed first.
for v in wide_sets:
    print(v)
    clf_multi_dict[v] = run(x_train[v], y_train[v], l1, l2)
# Narrow feature sets (5 columns or fewer): the two ratios are swapped.
for v in narrow_sets:
    print(v)
    clf_multi_dict[v] = run(x_train[v], y_train[v], l2, l1)
34_resp cat_34_resp bin_34_resp lit_34_resp ola_34_resp 34_exp cat_34_exp bin_34_exp 34_multi cat_34_multi bin_34_multi lit_34_multi ola_34_multi 34_lithium2y cat_34_lithium2y bin_34_lithium2y 37_resp cat_37_resp bin_37_resp lit_37_resp ola_37_resp 37_exp cat_37_exp bin_37_exp 37_multi cat_37_multi bin_37_multi lit_37_multi ola_37_multi 37_lithium2y cat_37_lithium2y bin_37_lithium2y 13_resp cat_13_resp bin_13_resp lit_13_resp ola_13_resp 13_exp cat_13_exp bin_13_exp 13_multi cat_13_multi bin_13_multi lit_13_multi ola_13_multi 13_lithium2y cat_13_lithium2y bin_13_lithium2y num_34_resp num_34_exp num_34_multi num_34_lithium2y num_37_resp num_37_exp num_37_multi num_37_lithium2y num_13_resp num_13_exp num_13_multi num_13_lithium2y
all_results = evaluate(clf_multi_dict, x_test, y_test)
________________________________________34_resp________________________________________ Balanced Accuracy: 0.5807004496516058 ROC_AUC score: 0.6445436122012334
___________________________________________________________________________________ ________________________________________cat_34_resp________________________________________ Balanced Accuracy: 0.5206806844315283 ROC_AUC score: 0.5627429177011705
___________________________________________________________________________________ ________________________________________bin_34_resp________________________________________ Balanced Accuracy: 0.517216479216485 ROC_AUC score: 0.5665701765426024
___________________________________________________________________________________ ________________________________________lit_34_resp________________________________________ Balanced Accuracy: 0.576376784572234 ROC_AUC score: 0.6225357816631882
___________________________________________________________________________________ ________________________________________ola_34_resp________________________________________ Balanced Accuracy: 0.5623262698430254 ROC_AUC score: 0.6852710006119724
___________________________________________________________________________________ ________________________________________34_exp________________________________________ Balanced Accuracy: 0.5567292093418088 ROC_AUC score: 0.597591654042374
___________________________________________________________________________________ ________________________________________cat_34_exp________________________________________ Balanced Accuracy: 0.5528329793498448 ROC_AUC score: 0.585100734306033
___________________________________________________________________________________ ________________________________________bin_34_exp________________________________________ Balanced Accuracy: 0.5471165025089209 ROC_AUC score: 0.5855459922694032
___________________________________________________________________________________ ________________________________________34_multi________________________________________ Balanced Accuracy: 0.3857732604055686 ROC_AUC score: 0.6075739122604424
___________________________________________________________________________________ ________________________________________cat_34_multi________________________________________ Balanced Accuracy: 0.34700607363579766 ROC_AUC score: 0.5446438972107863
___________________________________________________________________________________ ________________________________________bin_34_multi________________________________________ Balanced Accuracy: 0.3461819988708128 ROC_AUC score: 0.5478197984652663
___________________________________________________________________________________ ________________________________________lit_34_multi________________________________________ Balanced Accuracy: 0.38412353361957013 ROC_AUC score: 0.5939621512078251
___________________________________________________________________________________ ________________________________________ola_34_multi________________________________________ Balanced Accuracy: 0.37496773392901467 ROC_AUC score: 0.6348224798138037
___________________________________________________________________________________ ________________________________________34_lithium2y________________________________________ Balanced Accuracy: 0.5138532834389381 ROC_AUC score: 0.6187926946544635
___________________________________________________________________________________ ________________________________________cat_34_lithium2y________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5
___________________________________________________________________________________ ________________________________________bin_34_lithium2y________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5
___________________________________________________________________________________ ________________________________________37_resp________________________________________ Balanced Accuracy: 0.5830973043786685 ROC_AUC score: 0.6499600977105525
___________________________________________________________________________________ ________________________________________cat_37_resp________________________________________ Balanced Accuracy: 0.5210064816192035 ROC_AUC score: 0.5626445206576584
___________________________________________________________________________________ ________________________________________bin_37_resp________________________________________ Balanced Accuracy: 0.517216479216485 ROC_AUC score: 0.5665701765426026
___________________________________________________________________________________ ________________________________________lit_37_resp________________________________________ Balanced Accuracy: 0.5786410686683147 ROC_AUC score: 0.6238457544171984
___________________________________________________________________________________ ________________________________________ola_37_resp________________________________________ Balanced Accuracy: 0.5715256733541256 ROC_AUC score: 0.6935510628414695
___________________________________________________________________________________ ________________________________________37_exp________________________________________ Balanced Accuracy: 0.5539152712838864 ROC_AUC score: 0.5994841445594751
___________________________________________________________________________________ ________________________________________cat_37_exp________________________________________ Balanced Accuracy: 0.5491100712496624 ROC_AUC score: 0.5844461374697395
___________________________________________________________________________________ ________________________________________bin_37_exp________________________________________ Balanced Accuracy: 0.5471165025089209 ROC_AUC score: 0.5855459398429387
___________________________________________________________________________________ ________________________________________37_multi________________________________________ Balanced Accuracy: 0.3875424076373301 ROC_AUC score: 0.6095569396838123
___________________________________________________________________________________ ________________________________________cat_37_multi________________________________________ Balanced Accuracy: 0.34708807903337985 ROC_AUC score: 0.5440369409509134
___________________________________________________________________________________ ________________________________________bin_37_multi________________________________________ Balanced Accuracy: 0.3461819988708128 ROC_AUC score: 0.5478197389143789
___________________________________________________________________________________ ________________________________________lit_37_multi________________________________________ Balanced Accuracy: 0.3850145106903992 ROC_AUC score: 0.5951377257494506
___________________________________________________________________________________ ________________________________________ola_37_multi________________________________________ Balanced Accuracy: 0.38150595681310495 ROC_AUC score: 0.6409861821257834
___________________________________________________________________________________ ________________________________________37_lithium2y________________________________________ Balanced Accuracy: 0.512551877672003 ROC_AUC score: 0.619842045721224
___________________________________________________________________________________ ________________________________________cat_37_lithium2y________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5
___________________________________________________________________________________ ________________________________________bin_37_lithium2y________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5
___________________________________________________________________________________ ________________________________________13_resp________________________________________ Balanced Accuracy: 0.5822266049587533 ROC_AUC score: 0.6433300963375705
___________________________________________________________________________________ ________________________________________cat_13_resp________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5031357621766341
___________________________________________________________________________________ ________________________________________bin_13_resp________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5
___________________________________________________________________________________ ________________________________________lit_13_resp________________________________________ Balanced Accuracy: 0.576376784572234 ROC_AUC score: 0.6225498156535128
___________________________________________________________________________________ ________________________________________ola_13_resp________________________________________ Balanced Accuracy: 0.5519614457298325 ROC_AUC score: 0.6735587309311567
___________________________________________________________________________________ ________________________________________13_exp________________________________________ Balanced Accuracy: 0.519108292943985 ROC_AUC score: 0.5672112518097616
___________________________________________________________________________________ ________________________________________cat_13_exp________________________________________ Balanced Accuracy: 0.5054545542240632 ROC_AUC score: 0.5338737348445577
___________________________________________________________________________________ ________________________________________bin_13_exp________________________________________ Balanced Accuracy: 0.5003936178957946 ROC_AUC score: 0.5499176485094946
___________________________________________________________________________________ ________________________________________13_multi________________________________________ Balanced Accuracy: 0.3878394785017754 ROC_AUC score: 0.6070247192189846
___________________________________________________________________________________ ________________________________________cat_13_multi________________________________________ Balanced Accuracy: 0.3333333333333333 ROC_AUC score: 0.5003737582196537
___________________________________________________________________________________ ________________________________________bin_13_multi________________________________________ Balanced Accuracy: 0.3333333333333333 ROC_AUC score: 0.5
___________________________________________________________________________________ ________________________________________lit_13_multi________________________________________ Balanced Accuracy: 0.38402994192011985 ROC_AUC score: 0.5939708088740229
___________________________________________________________________________________ ________________________________________ola_13_multi________________________________________ Balanced Accuracy: 0.36855795482750064 ROC_AUC score: 0.6250043455668536
___________________________________________________________________________________ ________________________________________13_lithium2y________________________________________ Balanced Accuracy: 0.5084324333976145 ROC_AUC score: 0.6024577644964135
___________________________________________________________________________________ ________________________________________cat_13_lithium2y________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5
___________________________________________________________________________________ ________________________________________bin_13_lithium2y________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5
___________________________________________________________________________________ ________________________________________num_34_resp________________________________________ Balanced Accuracy: 0.5824969107904944 ROC_AUC score: 0.6438724242285558
___________________________________________________________________________________ ________________________________________num_34_exp________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5455372076910043
___________________________________________________________________________________ ________________________________________num_34_multi________________________________________ Balanced Accuracy: 0.38732593956959604 ROC_AUC score: 0.6073454289552751
___________________________________________________________________________________ ________________________________________num_34_lithium2y________________________________________ Balanced Accuracy: 0.5045200853000296 ROC_AUC score: 0.6013688754374688
___________________________________________________________________________________ ________________________________________num_37_resp________________________________________ Balanced Accuracy: 0.5860582258784225 ROC_AUC score: 0.64826060342559
___________________________________________________________________________________ ________________________________________num_37_exp________________________________________ Balanced Accuracy: 0.5007556750599339 ROC_AUC score: 0.5560399061859873
___________________________________________________________________________________ ________________________________________num_37_multi________________________________________ Balanced Accuracy: 0.3902299311659596 ROC_AUC score: 0.6105212605735918
___________________________________________________________________________________ ________________________________________num_37_lithium2y________________________________________ Balanced Accuracy: 0.5076257308151458 ROC_AUC score: 0.6052576282816533
___________________________________________________________________________________ ________________________________________num_13_resp________________________________________ Balanced Accuracy: 0.5824969107904944 ROC_AUC score: 0.6438352392993214
___________________________________________________________________________________ ________________________________________num_13_exp________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5454701018163882
___________________________________________________________________________________ ________________________________________num_13_multi________________________________________ Balanced Accuracy: 0.38732593956959604 ROC_AUC score: 0.6073560975471447
___________________________________________________________________________________ ________________________________________num_13_lithium2y________________________________________ Balanced Accuracy: 0.5044206026100176 ROC_AUC score: 0.6013454396114563
___________________________________________________________________________________
display(all_results)
| features | balanced accuracy | accuracy | roc_auc | f1 (response) | f1 (equivocal) | f1 (no response) | f1 (lithium) | f1 (olanzapine) | f1 (lithium > 2y) | f1 (other) | f1 (weighted avg) | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 34_resp | 0.580700 | 0.600151 | 0.644544 | 0.416942 | NaN | 0.695752 | NaN | NaN | NaN | NaN | 0.566286 |
| 1 | cat_34_resp | 0.520681 | 0.551113 | 0.562743 | 0.162562 | NaN | 0.693378 | NaN | NaN | NaN | NaN | 0.446892 |
| 2 | bin_34_resp | 0.517216 | 0.548849 | 0.566570 | 0.131445 | NaN | 0.695287 | NaN | NaN | NaN | NaN | 0.433465 |
| 3 | lit_34_resp | 0.576377 | 0.561359 | 0.622536 | 0.527458 | NaN | 0.590722 | NaN | NaN | NaN | NaN | 0.555586 |
| 4 | ola_34_resp | 0.562326 | 0.635294 | 0.685271 | 0.297371 | NaN | 0.753734 | NaN | NaN | NaN | NaN | 0.570544 |
| 5 | 34_exp | 0.556729 | 0.522952 | 0.597592 | NaN | NaN | NaN | 0.564147 | 0.473156 | NaN | NaN | 0.510661 |
| 6 | cat_34_exp | 0.552833 | 0.517692 | 0.585101 | NaN | NaN | NaN | 0.562717 | 0.462331 | NaN | NaN | 0.503708 |
| 7 | bin_34_exp | 0.547117 | 0.510360 | 0.585546 | NaN | NaN | NaN | 0.560137 | 0.447879 | NaN | NaN | 0.494150 |
| 8 | 34_multi | 0.385773 | 0.509882 | 0.607574 | 0.394888 | 0.0 | 0.626716 | NaN | NaN | NaN | NaN | 0.442636 |
| 9 | cat_34_multi | 0.347006 | 0.469238 | 0.544644 | 0.158200 | 0.0 | 0.623879 | NaN | NaN | NaN | NaN | 0.347330 |
| 10 | bin_34_multi | 0.346182 | 0.468441 | 0.547820 | 0.150086 | 0.0 | 0.623939 | NaN | NaN | NaN | NaN | 0.344135 |
| 11 | lit_34_multi | 0.384124 | 0.471301 | 0.593962 | 0.491928 | 0.0 | 0.528982 | NaN | NaN | NaN | NaN | 0.427151 |
| 12 | ola_34_multi | 0.374968 | 0.542003 | 0.634822 | 0.290523 | 0.0 | 0.680728 | NaN | NaN | NaN | NaN | 0.447247 |
| 13 | 34_lithium2y | 0.513853 | 0.801084 | 0.618793 | NaN | NaN | NaN | NaN | NaN | 0.068657 | 0.888651 | 0.725541 |
| 14 | cat_34_lithium2y | 0.500000 | 0.801084 | 0.500000 | NaN | NaN | NaN | NaN | NaN | 0.000000 | 0.889558 | 0.712610 |
| 15 | bin_34_lithium2y | 0.500000 | 0.801084 | 0.500000 | NaN | NaN | NaN | NaN | NaN | 0.000000 | 0.889558 | 0.712610 |
| 16 | 37_resp | 0.583097 | 0.602603 | 0.649960 | 0.419719 | NaN | 0.697835 | NaN | NaN | NaN | NaN | 0.568691 |
| 17 | cat_37_resp | 0.521006 | 0.551867 | 0.562645 | 0.154448 | NaN | 0.695150 | NaN | NaN | NaN | NaN | 0.444074 |
| 18 | bin_37_resp | 0.517216 | 0.548849 | 0.566570 | 0.131445 | NaN | 0.695287 | NaN | NaN | NaN | NaN | 0.433465 |
| 19 | lit_37_resp | 0.578641 | 0.563247 | 0.623846 | 0.527891 | NaN | 0.593677 | NaN | NaN | NaN | NaN | 0.557141 |
| 20 | ola_37_resp | 0.571526 | 0.643765 | 0.693551 | 0.316170 | NaN | 0.759147 | NaN | NaN | NaN | NaN | 0.581331 |
| 21 | 37_exp | 0.553915 | 0.518489 | 0.599484 | NaN | NaN | NaN | 0.564006 | 0.462360 | NaN | NaN | 0.504256 |
| 22 | cat_37_exp | 0.549110 | 0.512432 | 0.584446 | NaN | NaN | NaN | 0.561685 | 0.450709 | NaN | NaN | 0.496451 |
| 23 | bin_37_exp | 0.547117 | 0.510360 | 0.585546 | NaN | NaN | NaN | 0.560137 | 0.447879 | NaN | NaN | 0.494150 |
| 24 | 37_multi | 0.387542 | 0.512113 | 0.609557 | 0.398541 | 0.0 | 0.628659 | NaN | NaN | NaN | NaN | 0.444973 |
| 25 | cat_37_multi | 0.347088 | 0.469398 | 0.544037 | 0.157211 | 0.0 | 0.624066 | NaN | NaN | NaN | NaN | 0.347023 |
| 26 | bin_37_multi | 0.346182 | 0.468441 | 0.547820 | 0.150086 | 0.0 | 0.623939 | NaN | NaN | NaN | NaN | 0.344135 |
| 27 | lit_37_multi | 0.385015 | 0.471827 | 0.595138 | 0.490578 | 0.0 | 0.530845 | NaN | NaN | NaN | NaN | 0.427218 |
| 28 | ola_37_multi | 0.381506 | 0.548869 | 0.640986 | 0.313316 | 0.0 | 0.685266 | NaN | NaN | NaN | NaN | 0.457389 |
| 29 | 37_lithium2y | 0.512552 | 0.800446 | 0.619842 | NaN | NaN | NaN | NaN | NaN | 0.064275 | 0.888314 | 0.724399 |
| 30 | cat_37_lithium2y | 0.500000 | 0.801084 | 0.500000 | NaN | NaN | NaN | NaN | NaN | 0.000000 | 0.889558 | 0.712610 |
| 31 | bin_37_lithium2y | 0.500000 | 0.801084 | 0.500000 | NaN | NaN | NaN | NaN | NaN | 0.000000 | 0.889558 | 0.712610 |
| 32 | 13_resp | 0.582227 | 0.599585 | 0.643330 | 0.439989 | NaN | 0.688390 | NaN | NaN | NaN | NaN | 0.573044 |
| 33 | cat_13_resp | 0.500000 | 0.535647 | 0.503136 | 0.000000 | NaN | 0.697617 | NaN | NaN | NaN | NaN | 0.373677 |
| 34 | bin_13_resp | 0.500000 | 0.535647 | 0.500000 | 0.000000 | NaN | 0.697617 | NaN | NaN | NaN | NaN | 0.373677 |
| 35 | lit_13_resp | 0.576377 | 0.561359 | 0.622550 | 0.527458 | NaN | 0.590722 | NaN | NaN | NaN | NaN | 0.555586 |
| 36 | ola_13_resp | 0.551961 | 0.626353 | 0.673559 | 0.272894 | NaN | 0.748575 | NaN | NaN | NaN | NaN | 0.557631 |
| 37 | 13_exp | 0.519108 | 0.441026 | 0.567211 | NaN | NaN | NaN | 0.586974 | 0.135568 | NaN | NaN | 0.321628 |
| 38 | cat_13_exp | 0.505455 | 0.421103 | 0.533874 | NaN | NaN | NaN | 0.583963 | 0.048717 | NaN | NaN | 0.269333 |
| 39 | bin_13_exp | 0.500394 | 0.413931 | 0.549918 | NaN | NaN | NaN | 0.582681 | 0.016056 | NaN | NaN | 0.249606 |
| 40 | 13_multi | 0.387839 | 0.510679 | 0.607025 | 0.417533 | 0.0 | 0.622216 | NaN | NaN | NaN | NaN | 0.449578 |
| 41 | cat_13_multi | 0.333333 | 0.456009 | 0.500374 | 0.000000 | 0.0 | 0.626382 | NaN | NaN | NaN | NaN | 0.285636 |
| 42 | bin_13_multi | 0.333333 | 0.456009 | 0.500000 | 0.000000 | 0.0 | 0.626382 | NaN | NaN | NaN | NaN | 0.285636 |
| 43 | lit_13_multi | 0.384030 | 0.471301 | 0.593971 | 0.492570 | 0.0 | 0.528489 | NaN | NaN | NaN | NaN | 0.427267 |
| 44 | ola_13_multi | 0.368558 | 0.535137 | 0.625004 | 0.265834 | 0.0 | 0.678005 | NaN | NaN | NaN | NaN | 0.437382 |
| 45 | 13_lithium2y | 0.508432 | 0.801084 | 0.602458 | NaN | NaN | NaN | NaN | NaN | 0.042945 | 0.889007 | 0.720712 |
| 46 | cat_13_lithium2y | 0.500000 | 0.801084 | 0.500000 | NaN | NaN | NaN | NaN | NaN | 0.000000 | 0.889558 | 0.712610 |
| 47 | bin_13_lithium2y | 0.500000 | 0.801084 | 0.500000 | NaN | NaN | NaN | NaN | NaN | 0.000000 | 0.889558 | 0.712610 |
| 48 | num_34_resp | 0.582497 | 0.599585 | 0.643872 | 0.442928 | NaN | 0.687472 | NaN | NaN | NaN | NaN | 0.573918 |
| 49 | num_34_exp | 0.500000 | 0.412177 | 0.545537 | NaN | NaN | NaN | 0.583747 | 0.000000 | NaN | NaN | 0.240607 |
| 50 | num_34_multi | 0.387326 | 0.509882 | 0.607345 | 0.418582 | 0.0 | 0.620580 | NaN | NaN | NaN | NaN | 0.449249 |
| 51 | num_34_lithium2y | 0.504520 | 0.800606 | 0.601369 | NaN | NaN | NaN | NaN | NaN | 0.024942 | 0.888948 | 0.717083 |
| 52 | num_37_resp | 0.586058 | 0.603169 | 0.648261 | 0.447479 | NaN | 0.690406 | NaN | NaN | NaN | NaN | 0.577602 |
| 53 | num_37_exp | 0.500756 | 0.413134 | 0.556040 | NaN | NaN | NaN | 0.584049 | 0.003788 | NaN | NaN | 0.242958 |
| 54 | num_37_multi | 0.390230 | 0.513548 | 0.610521 | 0.425006 | 0.0 | 0.623219 | NaN | NaN | NaN | NaN | 0.453004 |
| 55 | num_37_lithium2y | 0.507626 | 0.801721 | 0.605258 | NaN | NaN | NaN | NaN | NaN | 0.037152 | 0.889481 | 0.719939 |
| 56 | num_13_resp | 0.582497 | 0.599585 | 0.643835 | 0.442928 | NaN | 0.687472 | NaN | NaN | NaN | NaN | 0.573918 |
| 57 | num_13_exp | 0.500000 | 0.412177 | 0.545470 | NaN | NaN | NaN | 0.583747 | 0.000000 | NaN | NaN | 0.240607 |
| 58 | num_13_multi | 0.387326 | 0.509882 | 0.607356 | 0.418582 | 0.0 | 0.620580 | NaN | NaN | NaN | NaN | 0.449249 |
| 59 | num_13_lithium2y | 0.504421 | 0.800446 | 0.601345 | NaN | NaN | NaN | NaN | NaN | 0.024922 | 0.888849 | 0.717000 |
import shap
def shap_plot(clf, X, sample=50):
    """Kernel-SHAP summary plot for a fitted classifier.

    Parameters
    ----------
    clf : fitted estimator; its predict() output is explained
    X : DataFrame of features
    sample : number of rows drawn for the background set and for the
             explained set (default 50)

    Bug fixed: the original called shap.sample(X, sample) three separate
    times, so the rows whose SHAP values were computed were NOT the rows
    whose feature values were plotted — the summary plot's colouring was
    misaligned with its SHAP values.  Sample once per role and reuse.
    """
    background = shap.sample(X, sample)      # reference distribution for KernelExplainer
    X_explained = shap.sample(X, sample)     # rows actually explained AND plotted
    explainer = shap.KernelExplainer(clf.predict, background)
    # l1_reg="num_features(k)" keeps all k features in the explanation.
    shap_values = explainer.shap_values(X_explained, l1_reg="num_features(" + str(X.shape[1]) + ")")
    shap.summary_plot(shap_values, X_explained, feature_names=X.columns)
shap_plot(clf_multi_dict['13_lithium2y'], X_dict['13_lithium2y'], 100)
X does not have valid feature names, but LogisticRegressionCV was fitted with feature names
0%| | 0/100 [00:00<?, ?it/s]
X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was 
fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but 
LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have 
valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with 
feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but 
LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have 
valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with 
feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but 
LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have 
valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names X does not have valid feature names, but LogisticRegressionCV was fitted with feature names
Inspired by: https://github.com/lmcinnes/umap/issues/58 More precisely: https://github.com/lmcinnes/umap/issues/58#issuecomment-419682509
def umap_embedding(X_dict, n_neighbors=15, weight=0.5, n_epochs=200):
    """Embed mixed-type data with UMAP by intersecting two fuzzy simplicial sets.

    Fits one UMAP on the numeric columns (Bray-Curtis metric) and one on the
    binary columns (Jaccard metric), intersects the two graphs with the given
    weight, and optimizes a joint low-dimensional embedding.
    Based on https://github.com/lmcinnes/umap/issues/58#issuecomment-419682509

    Parameters
    ----------
    X_dict : dict
        Must contain 'num' (numeric DataFrame) and 'bin' (binary DataFrame);
        both are assumed to describe the same rows.
    n_neighbors : int
        UMAP neighborhood size used for both fits.
    weight : float in [0, 1]
        Interpolation between the two graphs in the intersection
        (0 -> numeric graph only, 1 -> binary graph only).
    n_epochs : int
        Number of optimization epochs for the joint embedding (was a
        hard-coded 200).

    Returns
    -------
    The result of umap.simplicial_set_embedding — in recent umap-learn
    versions a tuple whose first element is the (n_samples, n_components)
    embedding array.
    """
    import umap.umap_ as umap
    fit1 = umap.UMAP(n_neighbors=n_neighbors, metric='braycurtis', random_state=42).fit(X_dict['num'].values)
    fit2 = umap.UMAP(n_neighbors=n_neighbors, metric='jaccard', random_state=42).fit(X_dict['bin'].values)
    intersection = umap.general_simplicial_set_intersection(fit1.graph_, fit2.graph_, weight=weight)
    intersection = umap.reset_local_connectivity(intersection)
    # Pass a seeded RandomState so the joint optimization is reproducible;
    # the original passed the global `np.random` module, which silently
    # defeated the random_state=42 set on both UMAP fits above.
    rng = np.random.RandomState(42)
    embedding = umap.simplicial_set_embedding(fit1._raw_data, intersection, fit1.n_components,
                                              fit1._initial_alpha, fit1._a, fit1._b,
                                              fit1.repulsion_strength, fit1.negative_sample_rate,
                                              n_epochs, 'random', rng, fit1.metric,
                                              fit1._metric_kwds, False,
                                              densmap_kwds={}, output_dens=False)
    return embedding
# Axes3D import enables the projection='3d' subplots used below.
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
# Global plotting defaults for all UMAP figures in this notebook.
sns.set(style='white', context='poster', rc={'figure.figsize':(14,10)})
def plot_umap(X_dict, n_neighbors, weight, min_dist=0.1, n_components=2):
    """Fit the mixed-metric UMAP intersection and save a 2D/3D scatter plot.

    Same graph-intersection scheme as umap_embedding, but parameterized by
    min_dist / n_components and saved to umap_images/ instead of returned
    for further use.

    NOTE(review): point colors come from the module-level `X.exposure`, not
    from X_dict — confirm `X` holds the same rows as X_dict['num']/['bin'].

    Parameters
    ----------
    X_dict : dict with 'num' (numeric) and 'bin' (binary) DataFrames.
    n_neighbors, weight, min_dist, n_components : UMAP hyperparameters.

    Returns
    -------
    The simplicial_set_embedding result (added for consistency with
    plot_umap_2y; existing callers that ignore the return are unaffected).
    """
    import umap.umap_ as umap
    import matplotlib.pyplot as plt
    print('n_neighbors={:.0f} min_dist={:.1f} weight={:.1f} {:.0f}D'.format(n_neighbors, min_dist, weight, n_components))
    fit1 = umap.UMAP(n_neighbors=n_neighbors, metric='braycurtis', random_state=42, min_dist=min_dist, n_components=n_components).fit(X_dict['num'].values)
    fit2 = umap.UMAP(n_neighbors=n_neighbors, metric='jaccard', random_state=42, min_dist=min_dist, n_components=n_components).fit(X_dict['bin'].values)
    intersection = umap.general_simplicial_set_intersection(fit1.graph_, fit2.graph_, weight=weight)
    intersection = umap.reset_local_connectivity(intersection)
    # Seeded RandomState instead of the global np.random module, so repeated
    # calls with the same parameters produce the same figure.
    rng = np.random.RandomState(42)
    embedding = umap.simplicial_set_embedding(fit1._raw_data, intersection, fit1.n_components,
                                              fit1._initial_alpha, fit1._a, fit1._b,
                                              fit1.repulsion_strength, fit1.negative_sample_rate,
                                              200, 'random', rng, fit1.metric,
                                              fit1._metric_kwds, False,
                                              densmap_kwds={}, output_dens=False)
    fig = plt.figure()
    if n_components == 3:
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(
            embedding[0][:, 0],
            embedding[0][:, 1],
            embedding[0][:, 2],
            c=[sns.color_palette()[x] for x in X.exposure])
    else:
        ax = fig.add_subplot(111)
        ax.scatter(
            embedding[0][:, 0],
            embedding[0][:, 1],
            c=[sns.color_palette()[x] for x in X.exposure])
    #plt.gca().set_aspect('equal', 'datalim')
    plt.title('n_neighbors={:.0f} min_dist={:.1f} weight={:.1f} {:.0f}D'.format(n_neighbors, min_dist, weight, n_components), fontsize=16)
    plt.savefig('umap_images/n' + str(n_neighbors) + '_md{:.1f}_w{:.1f}_'.format(min_dist, weight) + str(n_components) + 'd.png')
    # Close the figure so repeated calls in a parameter sweep don't
    # accumulate open figures and exhaust memory.
    plt.close(fig)
    return embedding
def plot_umap_2y(X_dict, n_neighbors, weight, min_dist=0.1, n_components=2):
    """Mixed-metric UMAP plot for the '13_lithium2y' outcome, saved to umap_images/.

    Like plot_umap but uses the '*_13_lithium2y' feature sets and colors
    points by the module-level `y_dict['13_lithium2y']` labels.

    NOTE(review): colors depend on the global y_dict, not on a parameter —
    confirm its rows align with X_dict['num_13_lithium2y'].

    Parameters
    ----------
    X_dict : dict with 'num_13_lithium2y' and 'cat_13_lithium2y' DataFrames.
    n_neighbors, weight, min_dist, n_components : UMAP hyperparameters.

    Returns
    -------
    The simplicial_set_embedding result (first element is the embedding array).
    """
    import umap.umap_ as umap
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
    print('n_neighbors={:.0f} min_dist={:.1f} weight={:.1f} {:.0f}D'.format(n_neighbors, min_dist, weight, n_components))
    fit1 = umap.UMAP(n_neighbors=n_neighbors, metric='braycurtis', random_state=42, min_dist=min_dist, n_components=n_components).fit(X_dict['num_13_lithium2y'].values)
    fit2 = umap.UMAP(n_neighbors=n_neighbors, metric='jaccard', random_state=42, min_dist=min_dist, n_components=n_components).fit(X_dict['cat_13_lithium2y'].values)
    intersection = umap.general_simplicial_set_intersection(fit1.graph_, fit2.graph_, weight=weight)
    intersection = umap.reset_local_connectivity(intersection)
    # Seeded RandomState instead of the global np.random module, so the
    # saved figure is reproducible for a given parameter combination.
    rng = np.random.RandomState(42)
    embedding = umap.simplicial_set_embedding(fit1._raw_data, intersection, fit1.n_components,
                                              fit1._initial_alpha, fit1._a, fit1._b,
                                              fit1.repulsion_strength, fit1.negative_sample_rate,
                                              200, 'random', rng, fit1.metric,
                                              fit1._metric_kwds, False,
                                              densmap_kwds={}, output_dens=False
                                              )
    fig = plt.figure(figsize=(36,18))
    if n_components == 3:
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(embedding[0][:, 0],
                   embedding[0][:, 1],
                   embedding[0][:, 2],
                   c=[sns.color_palette()[col] for col in y_dict['13_lithium2y'].astype(int)])
        classes = ['Lithium for more than 2 years', 'Other']
        class_colours = sns.color_palette()
        # Build exactly one legend proxy per class label. The original looped
        # over the full palette (10 colors), producing more handles than
        # labels in ax.legend().
        recs = [mpatches.Rectangle((0, 0), 1, 1, fc=class_colours[i])
                for i in range(len(classes))]
        ax.legend(recs, classes, loc=1)
    else:
        ax = fig.add_subplot(111)
        ax.scatter(
            embedding[0][:, 0],
            embedding[0][:, 1],
            c=[sns.color_palette()[x] for x in y_dict['13_lithium2y'].astype(int)])
    #plt.gca().set_aspect('equal', 'datalim')
    plt.title('n_neighbors={:.0f} min_dist={:.1f} weight={:.1f} {:.0f}D'.format(n_neighbors, min_dist, weight, n_components), fontsize=16)
    plt.savefig('umap_images/n' + str(n_neighbors) + '_md{:.1f}_w{:.1f}_'.format(min_dist, weight) + str(n_components) + 'd_2y.png')
    # These 36x18-inch figures are large; close them so the 400-run sweep
    # below doesn't exhaust memory.
    plt.close(fig)
    return embedding
# Single 3D run: n_neighbors=50, weight=0 (numeric graph only), min_dist=0.1.
embedding = plot_umap_2y(X_dict, 50, 0, 0.1, 3)
n_neighbors=50 min_dist=0.1 weight=0.0 3D
gradient function is not yet implemented for jaccard distance metric; inverse_transform will be unavailable Failed to correctly find n_neighbors for some samples.Results may be less than ideal. Try re-running withdifferent parameters. A few of your vertices were disconnected from the manifold. This shouldn't cause problems. Disconnection_distance = 1 has removed 250 edges. It has only fully disconnected 5 vertices. Use umap.utils.disconnected_vertices() to identify them.
<Figure size 1008x720 with 0 Axes>
# Parameter sweep: 10 neighbor values x 5 weights x 2 dimensionalities x
# 4 min_dist values = 400 figures written to umap_images/.
for n_neighbors in range(50, 250, 20):
    for weight in np.linspace(0, 1, 5):
        for n_components in range(2, 4):
            for min_dist in np.linspace(.1, .9, 4):
                plot_umap_2y(X_dict, n_neighbors, weight, min_dist=min_dist, n_components=n_components)
                # Close any figures left open by the plotting call; without
                # this, 400 large (36x18 in) figures accumulate in kernel
                # memory over the sweep.
                plt.close('all')
print('done')
from sklearn.metrics import silhouette_samples, silhouette_score
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
# PCA -> logistic-regression pipeline for the 'all34' feature set.
# NOTE(review): clf_multi_dict and y are defined in earlier cells — this cell
# fails on a fresh kernel unless those cells have run.
pca = PCA(n_components=9)
pipe = Pipeline([('pca', pca), ('logistic', clf_multi_dict['all34'])])
pipe.fit(X_dict['all34'], y)
#predictions = pipe.predict(X)
# Scree plot: variance explained by each of the 9 components.
plt.figure(1, figsize=(8, 6))
plt.clf()
plt.plot(pca.explained_variance_, linewidth=2)
plt.axis('tight')
plt.xlabel('n components')
plt.ylabel('explained variance')
pd.set_option('display.float_format', lambda x: '%.3f' % x)
# Component loadings: one row per principal component, one column per feature.
display(pd.DataFrame(pca.components_,columns=X_dict['all34'].columns))
from sklearn.ensemble import RandomForestClassifier
@ignore_warnings(category=ConvergenceWarning) # NOTE(review): likely leftover from an earlier LogisticRegression version — random forests don't raise ConvergenceWarning
def run(X1, y):
    """Fit a shallow (max_depth=2) random forest on X1/y.

    A pandas Series is promoted to a one-column DataFrame before fitting;
    any other input is passed through unchanged.
    """
    features = X1.to_frame() if isinstance(X1, pd.Series) else X1
    return RandomForestClassifier(max_depth=2, random_state=0).fit(features, y)
clf2_multi_dict = dict()
for v in X_dict:
print(v)
clf2_multi_dict.update({v: run(X_dict[v], y_dict[v])})
34_resp num_34_resp cat_34_resp bin_34_resp lit_34_resp ola_34_resp 34_exp num_34_exp cat_34_exp bin_34_exp 34_multi num_34_multi cat_34_multi bin_34_multi lit_34_multi ola_34_multi 34_lithium2y num_34_lithium2y cat_34_lithium2y bin_34_lithium2y 37_resp num_37_resp cat_37_resp bin_37_resp lit_37_resp ola_37_resp 37_exp num_37_exp cat_37_exp bin_37_exp 37_multi num_37_multi cat_37_multi bin_37_multi lit_37_multi ola_37_multi 37_lithium2y num_37_lithium2y cat_37_lithium2y bin_37_lithium2y 13_resp num_13_resp cat_13_resp bin_13_resp lit_13_resp ola_13_resp 13_exp num_13_exp cat_13_exp bin_13_exp 13_multi num_13_multi cat_13_multi bin_13_multi lit_13_multi ola_13_multi 13_lithium2y num_13_lithium2y cat_13_lithium2y bin_13_lithium2y
# Score every random-forest model on the held-out test split; evaluate()
# prints balanced accuracy / ROC-AUC per feature set and returns a summary
# table (displayed below).
all_results2 = evaluate(clf2_multi_dict, x_test, y_test)
________________________________________34_resp________________________________________ Balanced Accuracy: 0.5110331975606687 ROC_AUC score: 0.6441694746055537
___________________________________________________________________________________ ________________________________________num_34_resp________________________________________ Balanced Accuracy: 0.5773642464045033 ROC_AUC score: 0.6390544015514696
___________________________________________________________________________________ ________________________________________cat_34_resp________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5785016761821947
___________________________________________________________________________________ ________________________________________bin_34_resp________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5758893633940115
___________________________________________________________________________________ ________________________________________lit_34_resp________________________________________ Balanced Accuracy: 0.6069482286096927 ROC_AUC score: 0.6300764451501537
___________________________________________________________________________________ ________________________________________ola_34_resp________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.6840147979384636
___________________________________________________________________________________ ________________________________________34_exp________________________________________ Balanced Accuracy: 0.508012755987941 ROC_AUC score: 0.582772370162715
___________________________________________________________________________________ ________________________________________num_34_exp________________________________________ Balanced Accuracy: 0.5187031412259908 ROC_AUC score: 0.5527076800995852
___________________________________________________________________________________ ________________________________________cat_34_exp________________________________________ Balanced Accuracy: 0.509346904657651 ROC_AUC score: 0.5603262939270862
___________________________________________________________________________________ ________________________________________bin_34_exp________________________________________ Balanced Accuracy: 0.5144463220108947 ROC_AUC score: 0.5749915488539155
___________________________________________________________________________________ ________________________________________34_multi________________________________________ Balanced Accuracy: 0.34516011062967616 ROC_AUC score: 0.611353800587708
___________________________________________________________________________________ ________________________________________num_34_multi________________________________________ Balanced Accuracy: 0.3847279954832512 ROC_AUC score: 0.605524634439378
___________________________________________________________________________________ ________________________________________cat_34_multi________________________________________ Balanced Accuracy: 0.3333333333333333 ROC_AUC score: 0.56156310061291
___________________________________________________________________________________ ________________________________________bin_34_multi________________________________________ Balanced Accuracy: 0.3333333333333333 ROC_AUC score: 0.5625488160989097
___________________________________________________________________________________ ________________________________________lit_34_multi________________________________________ Balanced Accuracy: 0.40293891192888515 ROC_AUC score: 0.6005018451992608
___________________________________________________________________________________ ________________________________________ola_34_multi________________________________________ Balanced Accuracy: 0.3333333333333333 ROC_AUC score: 0.6420242532566227
___________________________________________________________________________________ ________________________________________34_lithium2y________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.6211246390563939
___________________________________________________________________________________ ________________________________________num_34_lithium2y________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.6013454396114563
___________________________________________________________________________________ ________________________________________cat_34_lithium2y________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5904663538063608
___________________________________________________________________________________ ________________________________________bin_34_lithium2y________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5786242468650198
___________________________________________________________________________________ ________________________________________37_resp________________________________________ Balanced Accuracy: 0.5845093305568586 ROC_AUC score: 0.6580121365888261
___________________________________________________________________________________ ________________________________________num_37_resp________________________________________ Balanced Accuracy: 0.6025075514010136 ROC_AUC score: 0.6566489513849956
___________________________________________________________________________________ ________________________________________cat_37_resp________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5794025526023729
___________________________________________________________________________________ ________________________________________bin_37_resp________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5777174317227491
___________________________________________________________________________________ ________________________________________lit_37_resp________________________________________ Balanced Accuracy: 0.601644382694887 ROC_AUC score: 0.6358305816688019
___________________________________________________________________________________ ________________________________________ola_37_resp________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.6925096035450168
___________________________________________________________________________________ ________________________________________37_exp________________________________________ Balanced Accuracy: 0.5125252066441527 ROC_AUC score: 0.581302279670443
___________________________________________________________________________________ ________________________________________num_37_exp________________________________________ Balanced Accuracy: 0.519850861387783 ROC_AUC score: 0.5650832616139299
___________________________________________________________________________________ ________________________________________cat_37_exp________________________________________ Balanced Accuracy: 0.5111294044521393 ROC_AUC score: 0.5589588020259263
___________________________________________________________________________________ ________________________________________bin_37_exp________________________________________ Balanced Accuracy: 0.5129434649782829 ROC_AUC score: 0.5686375662041394
___________________________________________________________________________________ ________________________________________37_multi________________________________________ Balanced Accuracy: 0.38740060468928744 ROC_AUC score: 0.6205182267712338
___________________________________________________________________________________ ________________________________________num_37_multi________________________________________ Balanced Accuracy: 0.4014441459086413 ROC_AUC score: 0.6173249209650806
___________________________________________________________________________________ ________________________________________cat_37_multi________________________________________ Balanced Accuracy: 0.3333333333333333 ROC_AUC score: 0.5630210871861139
___________________________________________________________________________________ ________________________________________bin_37_multi________________________________________ Balanced Accuracy: 0.3333333333333333 ROC_AUC score: 0.56365325092559
___________________________________________________________________________________ ________________________________________lit_37_multi________________________________________ Balanced Accuracy: 0.4018588955331041 ROC_AUC score: 0.6062636060009988
___________________________________________________________________________________ ________________________________________ola_37_multi________________________________________ Balanced Accuracy: 0.3333333333333333 ROC_AUC score: 0.6479986686804441
___________________________________________________________________________________ ________________________________________37_lithium2y________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.6199223971246952
___________________________________________________________________________________ ________________________________________num_37_lithium2y________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.60939492842348
___________________________________________________________________________________ ________________________________________cat_37_lithium2y________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5880130054485904
___________________________________________________________________________________ ________________________________________bin_37_lithium2y________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5778896851755487
___________________________________________________________________________________ ________________________________________13_resp________________________________________ Balanced Accuracy: 0.5477520280088328 ROC_AUC score: 0.637369066143408
___________________________________________________________________________________ ________________________________________num_13_resp________________________________________ Balanced Accuracy: 0.5788939771856157 ROC_AUC score: 0.6401284024210251
___________________________________________________________________________________ ________________________________________cat_13_resp________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5358554965046167
___________________________________________________________________________________ ________________________________________bin_13_resp________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5100019593597326
___________________________________________________________________________________ ________________________________________lit_13_resp________________________________________ Balanced Accuracy: 0.5995958210786525 ROC_AUC score: 0.6204130403838095
___________________________________________________________________________________ ________________________________________ola_13_resp________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.674893734285946
___________________________________________________________________________________ ________________________________________13_exp________________________________________ Balanced Accuracy: 0.5095903731589922 ROC_AUC score: 0.552450056452817
___________________________________________________________________________________ ________________________________________num_13_exp________________________________________ Balanced Accuracy: 0.5182964167140602 ROC_AUC score: 0.5508419270794014
___________________________________________________________________________________ ________________________________________cat_13_exp________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5219919057732861
___________________________________________________________________________________ ________________________________________bin_13_exp________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5737315312050705
___________________________________________________________________________________ ________________________________________13_multi________________________________________ Balanced Accuracy: 0.36479965529680997 ROC_AUC score: 0.605131810802567
___________________________________________________________________________________ ________________________________________num_13_multi________________________________________ Balanced Accuracy: 0.38581131764253085 ROC_AUC score: 0.6047799299462205
___________________________________________________________________________________ ________________________________________cat_13_multi________________________________________ Balanced Accuracy: 0.3333333333333333 ROC_AUC score: 0.5304525184350798
___________________________________________________________________________________ ________________________________________bin_13_multi________________________________________ Balanced Accuracy: 0.3333333333333333 ROC_AUC score: 0.5108637439107558
___________________________________________________________________________________ ________________________________________lit_13_multi________________________________________ Balanced Accuracy: 0.3999337699531935 ROC_AUC score: 0.593113733735595
___________________________________________________________________________________ ________________________________________ola_13_multi________________________________________ Balanced Accuracy: 0.3333333333333333 ROC_AUC score: 0.6325575057349074
___________________________________________________________________________________ ________________________________________13_lithium2y________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.6051245861265011
___________________________________________________________________________________ ________________________________________num_13_lithium2y________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.6015531734978113
___________________________________________________________________________________ ________________________________________cat_13_lithium2y________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5407494809044252
___________________________________________________________________________________ ________________________________________bin_13_lithium2y________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5400218543063251
___________________________________________________________________________________
# Render the random-forest results table (one row per feature-set variant).
display(all_results2)
| features | balanced accuracy | accuracy | roc_auc | f1 (response) | f1 (equivocal) | f1 (no response) | f1 (lithium) | f1 (olanzapine) | f1 (lithium > 2y) | f1 (other) | f1 (weighted avg) | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 34_resp | 0.511033 | 0.545266 | 0.644169 | 0.059306 | NaN | 0.700162 | NaN | NaN | NaN | NaN | 0.402578 |
| 1 | num_34_resp | 0.577364 | 0.594115 | 0.639054 | 0.439291 | NaN | 0.681939 | NaN | NaN | NaN | NaN | 0.569265 |
| 2 | cat_34_resp | 0.500000 | 0.535647 | 0.578502 | 0.000000 | NaN | 0.697617 | NaN | NaN | NaN | NaN | 0.373677 |
| 3 | bin_34_resp | 0.500000 | 0.535647 | 0.575889 | 0.000000 | NaN | 0.697617 | NaN | NaN | NaN | NaN | 0.373677 |
| 4 | lit_34_resp | 0.606948 | 0.610132 | 0.630076 | 0.644272 | NaN | 0.568743 | NaN | NaN | NaN | NaN | 0.610691 |
| 5 | ola_34_resp | 0.500000 | 0.598588 | 0.684015 | 0.000000 | NaN | 0.748896 | NaN | NaN | NaN | NaN | 0.448280 |
| 6 | 34_exp | 0.508013 | 0.424450 | 0.582772 | NaN | NaN | NaN | 0.584895 | 0.061834 | NaN | NaN | 0.277428 |
| 7 | num_34_exp | 0.518703 | 0.436882 | 0.552708 | NaN | NaN | NaN | 0.590377 | 0.099414 | NaN | NaN | 0.301778 |
| 8 | cat_34_exp | 0.509347 | 0.429551 | 0.560326 | NaN | NaN | NaN | 0.582039 | 0.101882 | NaN | NaN | 0.299792 |
| 9 | bin_34_exp | 0.514446 | 0.440708 | 0.574992 | NaN | NaN | NaN | 0.579307 | 0.165914 | NaN | NaN | 0.336305 |
| 10 | 34_multi | 0.345160 | 0.469238 | 0.611354 | 0.093006 | 0.0 | 0.630719 | NaN | NaN | NaN | NaN | 0.324555 |
| 11 | num_34_multi | 0.384728 | 0.505738 | 0.605525 | 0.421562 | 0.0 | 0.614419 | NaN | NaN | NaN | NaN | 0.447623 |
| 12 | cat_34_multi | 0.333333 | 0.456009 | 0.561563 | 0.000000 | 0.0 | 0.626382 | NaN | NaN | NaN | NaN | 0.285636 |
| 13 | bin_34_multi | 0.333333 | 0.456009 | 0.562549 | 0.000000 | 0.0 | 0.626382 | NaN | NaN | NaN | NaN | 0.285636 |
| 14 | lit_34_multi | 0.402939 | 0.512112 | 0.600502 | 0.597719 | 0.0 | 0.505909 | NaN | NaN | NaN | NaN | 0.467889 |
| 15 | ola_34_multi | 0.333333 | 0.510501 | 0.642024 | 0.000000 | 0.0 | 0.675936 | NaN | NaN | NaN | NaN | 0.345066 |
| 16 | 34_lithium2y | 0.500000 | 0.801084 | 0.621125 | NaN | NaN | NaN | NaN | NaN | 0.0 | 0.889558 | 0.712610 |
| 17 | num_34_lithium2y | 0.500000 | 0.801084 | 0.601345 | NaN | NaN | NaN | NaN | NaN | 0.0 | 0.889558 | 0.712610 |
| 18 | cat_34_lithium2y | 0.500000 | 0.801084 | 0.590466 | NaN | NaN | NaN | NaN | NaN | 0.0 | 0.889558 | 0.712610 |
| 19 | bin_34_lithium2y | 0.500000 | 0.801084 | 0.578624 | NaN | NaN | NaN | NaN | NaN | 0.0 | 0.889558 | 0.712610 |
| 20 | 37_resp | 0.584509 | 0.605998 | 0.658012 | 0.400230 | NaN | 0.706642 | NaN | NaN | NaN | NaN | 0.564359 |
| 21 | num_37_resp | 0.602508 | 0.618069 | 0.656649 | 0.483023 | NaN | 0.697174 | NaN | NaN | NaN | NaN | 0.597732 |
| 22 | cat_37_resp | 0.500000 | 0.535647 | 0.579403 | 0.000000 | NaN | 0.697617 | NaN | NaN | NaN | NaN | 0.373677 |
| 23 | bin_37_resp | 0.500000 | 0.535647 | 0.577717 | 0.000000 | NaN | 0.697617 | NaN | NaN | NaN | NaN | 0.373677 |
| 24 | lit_37_resp | 0.601644 | 0.606671 | 0.635831 | 0.646293 | NaN | 0.557052 | NaN | NaN | NaN | NaN | 0.606615 |
| 25 | ola_37_resp | 0.500000 | 0.598588 | 0.692510 | 0.000000 | NaN | 0.748896 | NaN | NaN | NaN | NaN | 0.448280 |
| 26 | 37_exp | 0.512525 | 0.433695 | 0.581302 | NaN | NaN | NaN | 0.583226 | 0.116828 | NaN | NaN | 0.309067 |
| 27 | num_37_exp | 0.519851 | 0.443258 | 0.565083 | NaN | NaN | NaN | 0.585990 | 0.150328 | NaN | NaN | 0.329898 |
| 28 | cat_37_exp | 0.511129 | 0.431782 | 0.558959 | NaN | NaN | NaN | 0.582797 | 0.109418 | NaN | NaN | 0.304534 |
| 29 | bin_37_exp | 0.512943 | 0.436564 | 0.568638 | NaN | NaN | NaN | 0.581012 | 0.140112 | NaN | NaN | 0.321841 |
| 30 | 37_multi | 0.387401 | 0.513548 | 0.620518 | 0.379188 | 0.0 | 0.634992 | NaN | NaN | NaN | NaN | 0.440174 |
| 31 | num_37_multi | 0.401444 | 0.526458 | 0.617325 | 0.458353 | 0.0 | 0.630303 | NaN | NaN | NaN | NaN | 0.469479 |
| 32 | cat_37_multi | 0.333333 | 0.456009 | 0.563021 | 0.000000 | 0.0 | 0.626382 | NaN | NaN | NaN | NaN | 0.285636 |
| 33 | bin_37_multi | 0.333333 | 0.456009 | 0.563653 | 0.000000 | 0.0 | 0.626451 | NaN | NaN | NaN | NaN | 0.285667 |
| 34 | lit_37_multi | 0.401859 | 0.511058 | 0.606264 | 0.596645 | 0.0 | 0.504175 | NaN | NaN | NaN | NaN | 0.466740 |
| 35 | ola_37_multi | 0.333333 | 0.510501 | 0.647999 | 0.000000 | 0.0 | 0.675936 | NaN | NaN | NaN | NaN | 0.345066 |
| 36 | 37_lithium2y | 0.500000 | 0.801084 | 0.619922 | NaN | NaN | NaN | NaN | NaN | 0.0 | 0.889558 | 0.712610 |
| 37 | num_37_lithium2y | 0.500000 | 0.801084 | 0.609395 | NaN | NaN | NaN | NaN | NaN | 0.0 | 0.889558 | 0.712610 |
| 38 | cat_37_lithium2y | 0.500000 | 0.801084 | 0.588013 | NaN | NaN | NaN | NaN | NaN | 0.0 | 0.889558 | 0.712610 |
| 39 | bin_37_lithium2y | 0.500000 | 0.801084 | 0.577890 | NaN | NaN | NaN | NaN | NaN | 0.0 | 0.889558 | 0.712610 |
| 40 | 13_resp | 0.547752 | 0.572614 | 0.637369 | 0.301910 | NaN | 0.692036 | NaN | NaN | NaN | NaN | 0.510880 |
| 41 | num_13_resp | 0.578894 | 0.595436 | 0.640128 | 0.443291 | NaN | 0.682269 | NaN | NaN | NaN | NaN | 0.571299 |
| 42 | cat_13_resp | 0.500000 | 0.535647 | 0.535855 | 0.000000 | NaN | 0.697617 | NaN | NaN | NaN | NaN | 0.373677 |
| 43 | bin_13_resp | 0.500000 | 0.535647 | 0.510002 | 0.000000 | NaN | 0.697617 | NaN | NaN | NaN | NaN | 0.373677 |
| 44 | lit_13_resp | 0.599596 | 0.606042 | 0.620413 | 0.649692 | NaN | 0.549964 | NaN | NaN | NaN | NaN | 0.605351 |
| 45 | ola_13_resp | 0.500000 | 0.598588 | 0.674894 | 0.000000 | NaN | 0.748896 | NaN | NaN | NaN | NaN | 0.448280 |
| 46 | 13_exp | 0.509590 | 0.424131 | 0.552450 | NaN | NaN | NaN | 0.587792 | 0.044938 | NaN | NaN | 0.268690 |
| 47 | num_13_exp | 0.518296 | 0.436404 | 0.550842 | NaN | NaN | NaN | 0.590172 | 0.097959 | NaN | NaN | 0.300838 |
| 48 | cat_13_exp | 0.500000 | 0.412177 | 0.521992 | NaN | NaN | NaN | 0.583747 | 0.000000 | NaN | NaN | 0.240607 |
| 49 | bin_13_exp | 0.500000 | 0.412177 | 0.573732 | NaN | NaN | NaN | 0.583747 | 0.000000 | NaN | NaN | 0.240607 |
| 50 | 13_multi | 0.364800 | 0.487727 | 0.605132 | 0.285375 | 0.0 | 0.624471 | NaN | NaN | NaN | NaN | 0.398113 |
| 51 | num_13_multi | 0.385811 | 0.507810 | 0.604780 | 0.417036 | 0.0 | 0.618374 | NaN | NaN | NaN | NaN | 0.447629 |
| 52 | cat_13_multi | 0.333333 | 0.456009 | 0.530453 | 0.000000 | 0.0 | 0.626382 | NaN | NaN | NaN | NaN | 0.285636 |
| 53 | bin_13_multi | 0.333333 | 0.456009 | 0.510864 | 0.000000 | 0.0 | 0.626382 | NaN | NaN | NaN | NaN | 0.285636 |
| 54 | lit_13_multi | 0.399934 | 0.509215 | 0.593114 | 0.596050 | 0.0 | 0.499676 | NaN | NaN | NaN | NaN | 0.464782 |
| 55 | ola_13_multi | 0.333333 | 0.510501 | 0.632558 | 0.000000 | 0.0 | 0.675936 | NaN | NaN | NaN | NaN | 0.345066 |
| 56 | 13_lithium2y | 0.500000 | 0.801084 | 0.605125 | NaN | NaN | NaN | NaN | NaN | 0.0 | 0.889558 | 0.712610 |
| 57 | num_13_lithium2y | 0.500000 | 0.801084 | 0.601553 | NaN | NaN | NaN | NaN | NaN | 0.0 | 0.889558 | 0.712610 |
| 58 | cat_13_lithium2y | 0.500000 | 0.801084 | 0.540749 | NaN | NaN | NaN | NaN | NaN | 0.0 | 0.889558 | 0.712610 |
| 59 | bin_13_lithium2y | 0.500000 | 0.801084 | 0.540022 | NaN | NaN | NaN | NaN | NaN | 0.0 | 0.889558 | 0.712610 |
# One horizontal bar chart of random-forest feature importances per feature
# set, sorted so the most important feature is at the top.
for name, clf in clf2_multi_dict.items():
    order = np.argsort(clf.feature_importances_)
    labels = X_dict[name].columns
    positions = np.arange(len(order))
    plt.figure(1, figsize=(8, 10))
    plt.clf()
    plt.title('Feature Importances ' + name)
    plt.barh(positions, clf.feature_importances_[order], color='b', align='center')
    plt.yticks(positions, [labels[i] for i in order])
    plt.xlabel('Relative Importance')
    plt.show()
import shap

# SHAP summary bar chart (mean |SHAP value| per feature) for each fitted
# random forest.
for name, clf in clf2_multi_dict.items():
    print(40 * '_' + name + 40 * '_')
    explainer = shap.TreeExplainer(clf)
    shap.summary_plot(explainer.shap_values(X_dict[name]), X_dict[name], plot_type="bar")
    print(83 * '_')
    print('\n\n\n')
________________________________________34_resp________________________________________
___________________________________________________________________________________ ________________________________________num_34_resp________________________________________
___________________________________________________________________________________ ________________________________________cat_34_resp________________________________________
___________________________________________________________________________________ ________________________________________bin_34_resp________________________________________
___________________________________________________________________________________ ________________________________________lit_34_resp________________________________________
___________________________________________________________________________________ ________________________________________ola_34_resp________________________________________
___________________________________________________________________________________ ________________________________________34_exp________________________________________
___________________________________________________________________________________ ________________________________________num_34_exp________________________________________
___________________________________________________________________________________ ________________________________________cat_34_exp________________________________________
___________________________________________________________________________________ ________________________________________bin_34_exp________________________________________
___________________________________________________________________________________ ________________________________________34_multi________________________________________
___________________________________________________________________________________ ________________________________________num_34_multi________________________________________
___________________________________________________________________________________ ________________________________________cat_34_multi________________________________________
___________________________________________________________________________________ ________________________________________bin_34_multi________________________________________
___________________________________________________________________________________ ________________________________________lit_34_multi________________________________________
___________________________________________________________________________________ ________________________________________ola_34_multi________________________________________
___________________________________________________________________________________ ________________________________________34_lithium2y________________________________________
___________________________________________________________________________________ ________________________________________num_34_lithium2y________________________________________
___________________________________________________________________________________ ________________________________________cat_34_lithium2y________________________________________
___________________________________________________________________________________ ________________________________________bin_34_lithium2y________________________________________
___________________________________________________________________________________ ________________________________________37_resp________________________________________
___________________________________________________________________________________ ________________________________________num_37_resp________________________________________
___________________________________________________________________________________ ________________________________________cat_37_resp________________________________________
___________________________________________________________________________________ ________________________________________bin_37_resp________________________________________
___________________________________________________________________________________ ________________________________________lit_37_resp________________________________________
___________________________________________________________________________________ ________________________________________ola_37_resp________________________________________
___________________________________________________________________________________ ________________________________________37_exp________________________________________
___________________________________________________________________________________ ________________________________________num_37_exp________________________________________
___________________________________________________________________________________ ________________________________________cat_37_exp________________________________________
___________________________________________________________________________________ ________________________________________bin_37_exp________________________________________
___________________________________________________________________________________ ________________________________________37_multi________________________________________
___________________________________________________________________________________ ________________________________________num_37_multi________________________________________
___________________________________________________________________________________ ________________________________________cat_37_multi________________________________________
___________________________________________________________________________________ ________________________________________bin_37_multi________________________________________
___________________________________________________________________________________ ________________________________________lit_37_multi________________________________________
___________________________________________________________________________________ ________________________________________ola_37_multi________________________________________
___________________________________________________________________________________ ________________________________________37_lithium2y________________________________________
___________________________________________________________________________________ ________________________________________num_37_lithium2y________________________________________
___________________________________________________________________________________ ________________________________________cat_37_lithium2y________________________________________
___________________________________________________________________________________ ________________________________________bin_37_lithium2y________________________________________
___________________________________________________________________________________ ________________________________________13_resp________________________________________
___________________________________________________________________________________ ________________________________________num_13_resp________________________________________
___________________________________________________________________________________ ________________________________________cat_13_resp________________________________________
___________________________________________________________________________________ ________________________________________bin_13_resp________________________________________
___________________________________________________________________________________ ________________________________________lit_13_resp________________________________________
___________________________________________________________________________________ ________________________________________ola_13_resp________________________________________
___________________________________________________________________________________ ________________________________________13_exp________________________________________
___________________________________________________________________________________ ________________________________________num_13_exp________________________________________
___________________________________________________________________________________ ________________________________________cat_13_exp________________________________________
___________________________________________________________________________________ ________________________________________bin_13_exp________________________________________
___________________________________________________________________________________ ________________________________________13_multi________________________________________
___________________________________________________________________________________ ________________________________________num_13_multi________________________________________
___________________________________________________________________________________ ________________________________________cat_13_multi________________________________________
___________________________________________________________________________________ ________________________________________bin_13_multi________________________________________
___________________________________________________________________________________ ________________________________________lit_13_multi________________________________________
___________________________________________________________________________________ ________________________________________ola_13_multi________________________________________
___________________________________________________________________________________ ________________________________________13_lithium2y________________________________________
___________________________________________________________________________________ ________________________________________num_13_lithium2y________________________________________
___________________________________________________________________________________ ________________________________________cat_13_lithium2y________________________________________
___________________________________________________________________________________ ________________________________________bin_13_lithium2y________________________________________
___________________________________________________________________________________
from sklearn.naive_bayes import GaussianNB, BernoulliNB
from mixed_naive_bayes import MixedNB

# Fit three Naive Bayes variants per base feature set:
#   GaussianNB  on the numeric-only columns  ('num_' + name)
#   BernoulliNB on the binary-only columns   ('bin_' + name)
#   MixedNB     on the combined frame, told which columns are categorical
clf3_dict = dict()
base_sets = [feats for feats in X_dict if ('_13' not in feats) and ('_3' not in feats)]
for base in base_sets:
    print(base)
    num_key = 'num_' + base
    clf3_dict[num_key] = GaussianNB().fit(X_dict[num_key], y_dict[num_key])
    bin_key = 'bin_' + base
    clf3_dict[bin_key] = BernoulliNB().fit(X_dict[bin_key], y_dict[bin_key])
    # Positional indices of the binary columns inside the combined frame,
    # so MixedNB models them as categorical rather than Gaussian.
    cat_idx = [X_dict[base].columns.get_loc(col) for col in X_dict[bin_key].columns]
    mixed = MixedNB(categorical_features=cat_idx)
    clf3_dict[base] = mixed.fit(np.array(X_dict[base]), np.array(y_dict[base]))
34_resp [2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2] 34_exp [2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2] 34_multi [2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2] 34_lithium2y [2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2] 37_resp [2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2] 37_exp [2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2] 37_multi [2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2] 37_lithium2y [2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2] 13_resp [2 2 2 2 2 2 2 2] 13_exp [2 2 2 2 2 2 2 2] 13_multi [2 2 2 2 2 2 2 2] 13_lithium2y [2 2 2 2 2 2 2 2]
all_results3 = evaluate(clf3_dict, x_test, y_test)
________________________________________num_34_resp________________________________________ Balanced Accuracy: 0.5857668962597682 ROC_AUC score: 0.6380105061726984
___________________________________________________________________________________ ________________________________________bin_34_resp________________________________________ Balanced Accuracy: 0.5367480063157173 ROC_AUC score: 0.5760540497248314
___________________________________________________________________________________ ________________________________________34_resp________________________________________ Balanced Accuracy: 0.5952593505795128 ROC_AUC score: 0.50115180318303
___________________________________________________________________________________ ________________________________________num_34_exp________________________________________ Balanced Accuracy: 0.5224358006485783 ROC_AUC score: 0.540462221070238
___________________________________________________________________________________ ________________________________________bin_34_exp________________________________________ Balanced Accuracy: 0.5493286896068099 ROC_AUC score: 0.5724976219355683
___________________________________________________________________________________ ________________________________________34_exp________________________________________ Balanced Accuracy: 0.5580377738968214 ROC_AUC score: 0.49567450211635156
___________________________________________________________________________________ ________________________________________num_34_multi________________________________________ Balanced Accuracy: 0.39002599299933854 ROC_AUC score: 0.6033948361067034
___________________________________________________________________________________ ________________________________________bin_34_multi________________________________________ Balanced Accuracy: 0.3575640114310101 ROC_AUC score: 0.5615403377461362
___________________________________________________________________________________ ________________________________________34_multi________________________________________ Balanced Accuracy: 0.3965460299008324 ROC_AUC score: None
___________________________________________________________________________________ ________________________________________num_34_lithium2y________________________________________ Balanced Accuracy: 0.5325429561153795 ROC_AUC score: 0.5977664541818442
___________________________________________________________________________________ ________________________________________bin_34_lithium2y________________________________________ Balanced Accuracy: 0.5004006410256411 ROC_AUC score: 0.584060401935576
___________________________________________________________________________________ ________________________________________34_lithium2y________________________________________ Balanced Accuracy: 0.5462642336771864 ROC_AUC score: 0.5140489008438173
___________________________________________________________________________________ ________________________________________num_37_resp________________________________________ Balanced Accuracy: 0.5951753984508186 ROC_AUC score: 0.6473412775597532
___________________________________________________________________________________ ________________________________________bin_37_resp________________________________________ Balanced Accuracy: 0.5367480063157173 ROC_AUC score: 0.5760540497248314
___________________________________________________________________________________ ________________________________________37_resp________________________________________ Balanced Accuracy: 0.6166868800128145 ROC_AUC score: 0.5057420681685564
___________________________________________________________________________________ ________________________________________num_37_exp________________________________________ Balanced Accuracy: 0.5224040302110647 ROC_AUC score: 0.5437385605454366
___________________________________________________________________________________ ________________________________________bin_37_exp________________________________________ Balanced Accuracy: 0.5493286896068099 ROC_AUC score: 0.5724976219355683
___________________________________________________________________________________ ________________________________________37_exp________________________________________ Balanced Accuracy: 0.5421201555849703 ROC_AUC score: 0.5031892066911268
___________________________________________________________________________________ ________________________________________num_37_multi________________________________________ Balanced Accuracy: 0.39620903166493027 ROC_AUC score: 0.610487813737447
___________________________________________________________________________________ ________________________________________bin_37_multi________________________________________ Balanced Accuracy: 0.3575640114310101 ROC_AUC score: 0.5615403377461362
___________________________________________________________________________________ ________________________________________37_multi________________________________________ Balanced Accuracy: 0.4107821949731159 ROC_AUC score: None
___________________________________________________________________________________ ________________________________________num_37_lithium2y________________________________________ Balanced Accuracy: 0.5497151989143633 ROC_AUC score: 0.6024776929198934
___________________________________________________________________________________ ________________________________________bin_37_lithium2y________________________________________ Balanced Accuracy: 0.5004006410256411 ROC_AUC score: 0.584060401935576
___________________________________________________________________________________ ________________________________________37_lithium2y________________________________________ Balanced Accuracy: 0.5824579175467058 ROC_AUC score: 0.5084447093064781
___________________________________________________________________________________ ________________________________________num_13_resp________________________________________ Balanced Accuracy: 0.5857668962597682 ROC_AUC score: 0.6379753235088843
___________________________________________________________________________________ ________________________________________bin_13_resp________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5130504084621457
___________________________________________________________________________________ ________________________________________13_resp________________________________________ Balanced Accuracy: 0.5832612041052162 ROC_AUC score: 0.4957072144483473
___________________________________________________________________________________ ________________________________________num_13_exp________________________________________ Balanced Accuracy: 0.5224358006485783 ROC_AUC score: 0.5404592851882236
___________________________________________________________________________________ ________________________________________bin_13_exp________________________________________ Balanced Accuracy: 0.5025319885316061 ROC_AUC score: 0.5606965296196942
___________________________________________________________________________________ ________________________________________13_exp________________________________________ Balanced Accuracy: 0.539042931822109 ROC_AUC score: 0.47635325287338964
___________________________________________________________________________________ ________________________________________num_13_multi________________________________________ Balanced Accuracy: 0.39002599299933854 ROC_AUC score: 0.6033941180690108
___________________________________________________________________________________ ________________________________________bin_13_multi________________________________________ Balanced Accuracy: 0.3333333333333333 ROC_AUC score: 0.5108453613008589
___________________________________________________________________________________ ________________________________________13_multi________________________________________ Balanced Accuracy: 0.38824810288881545 ROC_AUC score: None
___________________________________________________________________________________ ________________________________________num_13_lithium2y________________________________________ Balanced Accuracy: 0.5325429561153795 ROC_AUC score: 0.5977613525054333
___________________________________________________________________________________ ________________________________________bin_13_lithium2y________________________________________ Balanced Accuracy: 0.5 ROC_AUC score: 0.5417565837134082
___________________________________________________________________________________ ________________________________________13_lithium2y________________________________________ Balanced Accuracy: 0.5359470497005316 ROC_AUC score: 0.5166446975726223
___________________________________________________________________________________
roc_auc_score raises a ValueError for 34_multi: MixedNB's predict_proba() returns rows whose class probabilities do not sum to 1 (see the row printed below), so sklearn refuses to treat them as probabilities — which is why the `*_multi` MixedNB models show `None` in the ROC AUC column above.
# Demonstrate why evaluate() reported roc_auc=None for '34_multi':
# MixedNB.predict_proba() returns rows that do not sum to 1, which
# sklearn's multiclass roc_auc_score rejects. The exception is the
# demonstration, so catch and display it — this keeps the notebook
# runnable under Restart & Run All instead of dying on a traceback.
v = '34_multi'
y_score = clf3_dict[v].predict_proba(x_test[v])
print('first proba row:', y_score[0], '-> row sum:', y_score[0].sum())
try:
    roc_auc_score(y_test[v], y_score, average='weighted', multi_class='ovr')
except ValueError as err:
    print('roc_auc_score failed as expected:', err)
[4.16702351e-06 1.29765617e-06 3.37361391e-06]
--------------------------------------------------------------------------- ValueError Traceback (most recent call last) Input In [91], in <cell line: 5>() 3 y_score = clf3_dict[v].predict_proba(x_test[v2]) 4 print(y_score[0]) ----> 5 roc_auc_score(y_test[v2], y_score, average='weighted', multi_class='ovr') File ~/GoogleDrive/sics/projects/ucl/lithium/venv/lib/python3.9/site-packages/sklearn/metrics/_ranking.py:561, in roc_auc_score(y_true, y_score, average, sample_weight, max_fpr, multi_class, labels) 559 if multi_class == "raise": 560 raise ValueError("multi_class must be in ('ovo', 'ovr')") --> 561 return _multiclass_roc_auc_score( 562 y_true, y_score, labels, multi_class, average, sample_weight 563 ) 564 elif y_type == "binary": 565 labels = np.unique(y_true) File ~/GoogleDrive/sics/projects/ucl/lithium/venv/lib/python3.9/site-packages/sklearn/metrics/_ranking.py:628, in _multiclass_roc_auc_score(y_true, y_score, labels, multi_class, average, sample_weight) 626 # validation of the input y_score 627 if not np.allclose(1, y_score.sum(axis=1)): --> 628 raise ValueError( 629 "Target scores need to be probabilities for multiclass " 630 "roc_auc, i.e. they should sum up to 1.0 over classes" 631 ) 633 # validation for multiclass parameter specifications 634 average_options = ("macro", "weighted") ValueError: Target scores need to be probabilities for multiclass roc_auc, i.e. they should sum up to 1.0 over classes
display(all_results3)
| features | balanced accuracy | accuracy | roc_auc | f1 (response) | f1 (equivocal) | f1 (no response) | f1 (lithium) | f1 (olanzapine) | f1 (lithium > 2y) | f1 (other) | f1 (weighted avg) | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | num_34_resp | 0.585767 | 0.599208 | 0.638011 | 0.479294 | NaN | 0.674230 | NaN | NaN | NaN | NaN | 0.583711 |
| 1 | bin_34_resp | 0.536748 | 0.564127 | 0.576054 | 0.245511 | NaN | 0.693542 | NaN | NaN | NaN | NaN | 0.485497 |
| 2 | 34_resp | 0.595259 | 0.608827 | 0.501152 | 0.490167 | NaN | 0.682681 | NaN | NaN | NaN | NaN | 0.593286 |
| 3 | num_34_exp | 0.522436 | 0.472107 | 0.540462 | NaN | NaN | NaN | 0.558164 | 0.344418 | NaN | NaN | 0.432519 |
| 4 | bin_34_exp | 0.549329 | 0.529806 | 0.572498 | NaN | NaN | NaN | 0.536601 | 0.522808 | NaN | NaN | 0.528493 |
| 5 | 34_exp | 0.558038 | 0.546701 | 0.495675 | NaN | NaN | NaN | 0.531003 | 0.561382 | NaN | NaN | 0.548860 |
| 6 | num_34_multi | 0.390026 | 0.510201 | 0.603395 | 0.451128 | 0.0 | 0.610942 | NaN | NaN | NaN | NaN | 0.457781 |
| 7 | bin_34_multi | 0.357564 | 0.480236 | 0.56154 | 0.236620 | 0.0 | 0.625000 | NaN | NaN | NaN | NaN | 0.378990 |
| 8 | 34_multi | 0.396546 | 0.518648 | None | 0.462348 | 0.0 | 0.618352 | NaN | NaN | NaN | NaN | 0.465616 |
| 9 | num_34_lithium2y | 0.532543 | 0.762512 | 0.597766 | NaN | NaN | NaN | NaN | NaN | 0.201501 | 0.860513 | 0.729425 |
| 10 | bin_34_lithium2y | 0.500401 | 0.801243 | 0.58406 | NaN | NaN | NaN | NaN | NaN | 0.001601 | 0.889636 | 0.712992 |
| 11 | 34_lithium2y | 0.546264 | 0.768091 | 0.514049 | NaN | NaN | NaN | NaN | NaN | 0.233807 | 0.863367 | 0.738138 |
| 12 | num_37_resp | 0.595175 | 0.607318 | 0.647341 | 0.501198 | NaN | 0.676205 | NaN | NaN | NaN | NaN | 0.594940 |
| 13 | bin_37_resp | 0.536748 | 0.564127 | 0.576054 | 0.245511 | NaN | 0.693542 | NaN | NaN | NaN | NaN | 0.485497 |
| 14 | 37_resp | 0.616687 | 0.619012 | 0.505742 | 0.587418 | NaN | 0.646111 | NaN | NaN | NaN | NaN | 0.618857 |
| 15 | num_37_exp | 0.522404 | 0.493465 | 0.543739 | NaN | NaN | NaN | 0.527926 | 0.453576 | NaN | NaN | 0.484222 |
| 16 | bin_37_exp | 0.549329 | 0.529806 | 0.572498 | NaN | NaN | NaN | 0.536601 | 0.522808 | NaN | NaN | 0.528493 |
| 17 | 37_exp | 0.542120 | 0.516576 | 0.503189 | NaN | NaN | NaN | 0.539687 | 0.491022 | NaN | NaN | 0.511081 |
| 18 | num_37_multi | 0.396209 | 0.517055 | 0.610488 | 0.470327 | 0.0 | 0.613356 | NaN | NaN | NaN | NaN | 0.466507 |
| 19 | bin_37_multi | 0.357564 | 0.480236 | 0.56154 | 0.236620 | 0.0 | 0.625000 | NaN | NaN | NaN | NaN | 0.378990 |
| 20 | 37_multi | 0.410782 | 0.527574 | None | 0.545080 | 0.0 | 0.590033 | NaN | NaN | NaN | NaN | 0.485564 |
| 21 | num_37_lithium2y | 0.549715 | 0.743704 | 0.602478 | NaN | NaN | NaN | NaN | NaN | 0.261029 | 0.844967 | 0.728813 |
| 22 | bin_37_lithium2y | 0.500401 | 0.801243 | 0.58406 | NaN | NaN | NaN | NaN | NaN | 0.001601 | 0.889636 | 0.712992 |
| 23 | 37_lithium2y | 0.582458 | 0.687600 | 0.508445 | NaN | NaN | NaN | NaN | NaN | 0.341840 | 0.795193 | 0.705014 |
| 24 | num_13_resp | 0.585767 | 0.599208 | 0.637975 | 0.479294 | NaN | 0.674230 | NaN | NaN | NaN | NaN | 0.583711 |
| 25 | bin_13_resp | 0.500000 | 0.535647 | 0.51305 | 0.000000 | NaN | 0.697617 | NaN | NaN | NaN | NaN | 0.373677 |
| 26 | 13_resp | 0.583261 | 0.596379 | 0.495707 | 0.478811 | NaN | 0.670668 | NaN | NaN | NaN | NaN | 0.581579 |
| 27 | num_13_exp | 0.522436 | 0.472107 | 0.540459 | NaN | NaN | NaN | 0.558164 | 0.344418 | NaN | NaN | 0.432519 |
| 28 | bin_13_exp | 0.502532 | 0.418075 | 0.560697 | NaN | NaN | NaN | 0.582122 | 0.041984 | NaN | NaN | 0.264616 |
| 29 | 13_exp | 0.539043 | 0.493465 | 0.476353 | NaN | NaN | NaN | 0.565134 | 0.393511 | NaN | NaN | 0.464250 |
| 30 | num_13_multi | 0.390026 | 0.510201 | 0.603394 | 0.451128 | 0.0 | 0.610942 | NaN | NaN | NaN | NaN | 0.457781 |
| 31 | bin_13_multi | 0.333333 | 0.456009 | 0.510845 | 0.000000 | 0.0 | 0.626382 | NaN | NaN | NaN | NaN | 0.285636 |
| 32 | 13_multi | 0.388248 | 0.507651 | None | 0.450124 | 0.0 | 0.607827 | NaN | NaN | NaN | NaN | 0.455962 |
| 33 | num_13_lithium2y | 0.532543 | 0.762512 | 0.597761 | NaN | NaN | NaN | NaN | NaN | 0.201501 | 0.860513 | 0.729425 |
| 34 | bin_13_lithium2y | 0.500000 | 0.801084 | 0.541757 | NaN | NaN | NaN | NaN | NaN | 0.000000 | 0.889558 | 0.712610 |
| 35 | 13_lithium2y | 0.535947 | 0.764106 | 0.516645 | NaN | NaN | NaN | NaN | NaN | 0.209402 | 0.861371 | 0.731684 |
import lightgbm as lgb
from sklearn.model_selection import cross_val_score
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials, space_eval
from hyperopt.pyll import scope
import time

# Hyperopt search space for LightGBM. Integer-valued parameters are wrapped
# in scope.int() because hp.quniform yields floats, which LGBMClassifier
# rejects for e.g. n_estimators.
param_hyperopt = {
    'learning_rate': hp.loguniform('learning_rate', np.log(0.01), np.log(1)),
    'max_depth': scope.int(hp.quniform('max_depth', 5, 15, 1)),
    'n_estimators': scope.int(hp.quniform('n_estimators', 5, 35, 1)),
    'num_leaves': scope.int(hp.quniform('num_leaves', 5, 50, 1)),
    'boosting_type': hp.choice('boosting_type', ['gbdt', 'dart']),
    # Fixed: the hyperopt label was 'colsample_by_tree', which did not match
    # the parameter key 'colsample_bytree'; mismatched labels make fmin's raw
    # best_param dict confusing and error-prone. Labels must be unique, and
    # by convention should equal the parameter key.
    'colsample_bytree': hp.uniform('colsample_bytree', 0.6, 1.0),
    'reg_lambda': hp.uniform('reg_lambda', 0.0, 1.0),
}
def hyperopt_lgbm(param_space, X_train, y_train, X_test, y_test, num_eval, metric):
    """Tune an LGBMClassifier with hyperopt TPE and refit the best setting.

    Runs `num_eval` TPE trials, scoring each candidate by 5-fold
    cross-validation on the training split with `metric`, then refits a
    classifier with the best decoded parameters on the full training split
    and reports its test-set score.

    Returns:
        (clf_best, best_param_dict): the refitted classifier and the decoded
        best hyperparameter dict.
    """
    start = time.time()

    def objective_function(params):
        # hyperopt minimises, so negate the mean CV score.
        cv_score = cross_val_score(
            lgb.LGBMClassifier(**params), X_train, y_train, cv=5, scoring=metric
        ).mean()
        return {'loss': -cv_score, 'status': STATUS_OK}

    trials = Trials()
    best_param = fmin(
        objective_function,
        param_space,
        algo=tpe.suggest,
        max_evals=num_eval,
        trials=trials,
    )
    # Decode hyperopt's raw result (choice indices, float quniforms) back
    # into real estimator kwargs.
    best_param_dict = space_eval(param_space, best_param)
    loss = [trial['result']['loss'] for trial in trials.trials]
    clf_best = lgb.LGBMClassifier(**best_param_dict)
    clf_best.fit(X_train, y_train)
    print("")
    print("##### Results")
    print("Score best parameters: ", min(loss)*-1)
    print("Best parameters: ", best_param_dict)
    print("Test Score: ", clf_best.score(X_test, y_test))
    print("Time elapsed: ", time.time() - start)
    print("Parameter combinations evaluated: ", num_eval)
    return clf_best, best_param_dict
# Run the LightGBM hyperopt search for each (feature set, target, metric) combination.
results_hyperopt_lgbm = dict()
best_params_dict_lgbm = dict()
for target in ['lithium2y']:
    for feature_number in ['13']:
        feature_set = f'{feature_number}_{target}'
        print(f"\n{'_' * 10}{feature_set}{'_' * 10}")
        for metric in ['balanced_accuracy']:
            print('\n\t----> metric:', metric)
            run_key = f'{feature_set}_{metric}'
            results_hyperopt_lgbm[run_key], best_params_dict_lgbm[run_key] = hyperopt_lgbm(
                param_hyperopt,
                x_train[feature_set], y_train[feature_set],
                x_test[feature_set], y_test[feature_set],
                1000, metric,
            )
__________13_lithium2y__________
----> metric: balanced_accuracy
100%|██████████| 1000/1000 [22:38<00:00, 1.36s/trial, best loss: -0.5348214119885094]
##### Results
Score best parameters: 0.5348214119885094
Best parameters: {'boosting_type': 'gbdt', 'colsample_bytree': 0.600831897129334, 'learning_rate': 0.9933016655780533, 'max_depth': 12, 'n_estimators': 34, 'num_leaves': 50, 'reg_lambda': 0.6221105546939463}
Test Score: 0.7496015301243226
Time elapsed: 1359.1855659484863
Parameter combinations evaluated: 1000
all_results_hyper_lgbm = evaluate(results_hyperopt_lgbm, x_test, y_test)
________________________________________13_lithium2y_balanced_accuracy________________________________________ Balanced Accuracy: 0.5271952832450744 ROC_AUC score: 0.5663205179221893
___________________________________________________________________________________
display(all_results_hyper_lgbm)
| features | balanced accuracy | accuracy | roc_auc | f1 (response) | f1 (equivocal) | f1 (no response) | f1 (lithium) | f1 (olanzapine) | f1 (lithium > 2y) | f1 (other) | f1 (weighted avg) | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 13_lithium2y_balanced_accuracy | 0.527195 | 0.749602 | 0.566321 | NaN | NaN | NaN | NaN | NaN | 0.200509 | 0.851554 | 0.722051 |
from sklearn.model_selection import cross_val_score
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials, space_eval
from hyperopt.pyll import scope
import time

# Hyperopt search space for the random forest; integer-valued parameters go
# through scope.int() because hp.quniform yields floats.
param_hyperopt = {
    'max_depth': scope.int(hp.quniform('max_depth', 5, 15, 1)),
    'n_estimators': scope.int(hp.quniform('n_estimators', 5, 35, 1)),
    'min_samples_split': hp.uniform('min_samples_split', 0, 1),
    'min_samples_leaf': hp.randint('min_samples_leaf', 1, 10),
    'criterion': hp.choice('criterion', ['gini', 'entropy']),
    'max_features': hp.choice('max_features', ['sqrt', 'log2']),
}
def hyperopt_rf(param_space, X_train, y_train, X_test, y_test, num_eval, metric):
    """Tune a RandomForestClassifier with hyperopt TPE and refit the best setting.

    Runs `num_eval` TPE trials, scoring each candidate by 5-fold
    cross-validation on the training split with `metric`, then refits a
    forest with the best decoded parameters on the full training split and
    reports its test-set score.

    Returns:
        (clf_best, best_param_dict): the refitted classifier and the decoded
        best hyperparameter dict.
    """
    start = time.time()
    def objective_function(params):
        clf = RandomForestClassifier(**params, random_state=42)
        # No explicit clf.fit() here: cross_val_score clones and fits the
        # estimator internally, so the previous extra fit only doubled runtime.
        score = cross_val_score(clf, X_train, y_train, cv=5, scoring=metric).mean()
        return {'loss': -score, 'status': STATUS_OK}
    trials = Trials()
    best_param = fmin(objective_function,
                      param_space,
                      algo=tpe.suggest,
                      max_evals=num_eval,
                      trials=trials)
    # Decode hyperopt's raw result (choice indices, float quniforms) into
    # real estimator kwargs before refitting / reporting.
    best_param_dict = space_eval(param_space, best_param)
    loss = [x['result']['loss'] for x in trials.trials]
    # Use the same random_state as the search so the refit is reproducible.
    clf_best = RandomForestClassifier(**best_param_dict, random_state=42)
    clf_best.fit(X_train, y_train)
    print("##### Results")
    print("Score best parameters: ", min(loss)*-1)
    # Report decoded parameters: the raw `best_param` only shows choice
    # indices (e.g. 'criterion': 0), matching hyperopt_lgbm's output.
    print("Best parameters: ", best_param_dict)
    print("Test Score: ", clf_best.score(X_test, y_test))
    print("Time elapsed: ", time.time() - start)
    print("Parameter combinations evaluated: ", num_eval)
    return clf_best, best_param_dict
# Run the random-forest hyperopt search for each (feature set, target, metric) combination.
results_hyperopt = dict()
best_params_dict = dict()
for target in ['lithium2y']:
    for feature_number in ['13']:
        feature_set = f'{feature_number}_{target}'
        print(f"\n{'_' * 10}{feature_set}{'_' * 10}")
        for metric in ['balanced_accuracy']:
            print('\n\t----> metric:', metric)
            run_key = f'{feature_set}_{metric}'
            results_hyperopt[run_key], best_params_dict[run_key] = hyperopt_rf(
                param_hyperopt,
                x_train[feature_set], y_train[feature_set],
                x_test[feature_set], y_test[feature_set],
                1000, metric,
            )
__________13_lithium2y__________
----> metric: balanced_accuracy
100%|██████████| 1000/1000 [19:17<00:00, 1.16s/trial, best loss: -0.5106134589427425]
##### Results
Score best parameters: 0.5106134589427425
Best parameters: {'criterion': 0, 'max_depth': 14.0, 'max_features': 0, 'min_samples_leaf': 2, 'min_samples_split': 0.00035845671793269197, 'n_estimators': 16.0}
Test Score: 0.7975773031558814
Time elapsed: 1158.2907378673553
Parameter combinations evaluated: 1000
all_results_hyper_rf = evaluate(results_hyperopt, x_test, y_test)
________________________________________13_lithium2y_balanced_accuracy________________________________________ Balanced Accuracy: 0.5029310725254319 ROC_AUC score: 0.589657897522626
___________________________________________________________________________________
display(all_results_hyper_rf)
| features | balanced accuracy | accuracy | roc_auc | f1 (response) | f1 (equivocal) | f1 (no response) | f1 (lithium) | f1 (olanzapine) | f1 (lithium > 2y) | f1 (other) | f1 (weighted avg) | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 13_lithium2y_balanced_accuracy | 0.502931 | 0.797577 | 0.589658 | NaN | NaN | NaN | NaN | NaN | 0.026074 | 0.887051 | 0.715789 |
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
from umap import UMAP

feature_set = '13_lithium2y'

# Baseline: linear SVM on the raw features, C tuned by grid search.
svc = LinearSVC(dual=False, random_state=123)
params_grid = {"C": [10 ** k for k in range(-3, 4)]}
clf = GridSearchCV(svc, params_grid, scoring='balanced_accuracy')
clf.fit(x_train[feature_set], y_train[feature_set])
print(
    "Balanced accuracy on the test set with raw data: {:.3f}".format(clf.score(x_test[feature_set], y_test[feature_set]))
)

# Same SVM on a UMAP embedding; the embedding's neighbourhood size and
# dimensionality are tuned jointly with C by the grid search.
umap = UMAP(random_state=456)
pipeline = Pipeline([("umap", umap), ("svc", svc)])
params_grid_pipeline = {
    "umap__n_neighbors": [5, 20],
    "umap__n_components": [15, 25, 50],
    "svc__C": [10 ** k for k in range(-3, 4)],
}
clf_pipeline = GridSearchCV(pipeline, params_grid_pipeline, scoring='balanced_accuracy')
clf_pipeline.fit(x_train[feature_set], y_train[feature_set])
print(
    "Balanced accuracy on the test set with UMAP transformation: {:.3f}".format(
        clf_pipeline.score(x_test[feature_set], y_test[feature_set])
    )
)
Balanced accuracy on the test set with raw data: 0.504
OMP: Info #270: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
Balanced accuracy on the test set with UMAP transformation: 0.513
# Re-print both balanced-accuracy scores at higher precision for the write-up.
raw_score = clf.score(x_test[feature_set], y_test[feature_set])
umap_score = clf_pipeline.score(x_test[feature_set], y_test[feature_set])
print(f"Balanced accuracy on the test set with raw data: {raw_score:.6f}",
      f"\nBalanced accuracy on the test set with UMAP transformation: {umap_score:.6f}")
Balanced accuracy on the test set with raw data: 0.504214 Balanced accuracy on the test set with UMAP transformation: 0.513193